Orchard Controller: implement an SSH server that acts as a jump host (#179)

* proxy.Connections(): require io.ReadWriteCloser instead of net.Conn

* Orchard Controller: implement an SSH server that acts as a jump host

* Issue a warning if the name used will be invalid in the future

* Further restrict uppercase characters in names in the future

The rationale is similar to https://github.com/kubernetes/kubernetes/issues/71140.

We won't want to munge the user's input and introduce subtle bugs doing
lowercase comparisons.
This commit is contained in:
Nikolay Edigaryev 2024-06-11 19:32:45 +04:00 committed by GitHub
parent c845f3b2fd
commit d59bc7f8a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 763 additions and 152 deletions

View File

@ -2,8 +2,8 @@ package main
import (
"context"
"fmt"
"github.com/cirruslabs/orchard/internal/command"
"log"
"os"
"os/signal"
)
@ -25,6 +25,8 @@ func main() {
// Run the command
if err := command.NewRootCmd().ExecuteContext(ctx); err != nil {
log.Fatal(err)
_, _ = fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}

8
go.mod
View File

@ -26,10 +26,10 @@ require (
github.com/spf13/cobra v1.6.0
github.com/stretchr/testify v1.8.1
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.21.0
golang.org/x/crypto v0.23.0
golang.org/x/exp v0.0.0-20230321023759-10a507213a29
golang.org/x/net v0.23.0
golang.org/x/term v0.18.0
golang.org/x/term v0.20.0
google.golang.org/grpc v1.56.3
google.golang.org/protobuf v1.30.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
@ -102,8 +102,8 @@ require (
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)

16
go.sum
View File

@ -497,8 +497,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -629,12 +629,12 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8=
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -643,8 +643,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View File

@ -12,7 +12,10 @@ import (
"fmt"
"github.com/cirruslabs/orchard/internal/controller"
"github.com/cirruslabs/orchard/internal/netconstants"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
"math/big"
"os"
"time"
)
@ -20,25 +23,80 @@ var ErrInitFailed = errors.New("controller initialization failed")
var controllerCertPath string
var controllerKeyPath string
var sshHostKeyPath string
func FindControllerCertificate(dataDir *controller.DataDir) (controllerCert tls.Certificate, err error) {
func FindControllerCertificate(dataDir *controller.DataDir) (tls.Certificate, error) {
// Prefer user-specified certificate and key
if controllerCertPath != "" || controllerKeyPath != "" {
// if external certificate is specified, use it
if err = checkBothCertAndKeyAreSpecified(); err != nil {
return controllerCert, err
if err := checkBothCertAndKeyAreSpecified(); err != nil {
return tls.Certificate{}, err
}
return tls.LoadX509KeyPair(controllerCertPath, controllerKeyPath)
} else if !dataDir.ControllerCertificateExists() {
// otherwise, generate a self-signed certificate if it's not already present
controllerCert, err = GenerateSelfSignedControllerCertificate()
if err != nil {
return controllerCert, err
}
if err = dataDir.SetControllerCertificate(controllerCert); err != nil {
return controllerCert, err
}
}
return
// Fall back to loading the certificate from the Orchard data directory
controllerCert, err := dataDir.ControllerCertificate()
if err == nil {
return controllerCert, nil
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return tls.Certificate{}, err
}
// Fall back to generating a new certificate
controllerCert, err = GenerateSelfSignedControllerCertificate()
if err != nil {
return controllerCert, err
}
if err = dataDir.SetControllerCertificate(controllerCert); err != nil {
return controllerCert, err
}
return controllerCert, nil
}
func FindSSHHostKey(dataDir *controller.DataDir) (ssh.Signer, error) {
// Prefer user-specified host key
if sshHostKeyPath != "" {
hostKeyBytes, err := os.ReadFile(sshHostKeyPath)
if err != nil {
return nil, err
}
signer, err := ssh.ParsePrivateKey(hostKeyBytes)
if err != nil {
return nil, err
}
return signer, nil
}
// Fall back to loading the host key from the Orchard data directory
signer, err := dataDir.SSHHostKey()
if err == nil {
return signer, err
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, err
}
// Fall back to generating a new host key
_, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
return nil, err
}
if err := dataDir.SetSSHHostKey(privateKey); err != nil {
return nil, err
}
signer, err = ssh.NewSignerFromKey(privateKey)
if err != nil {
return nil, err
}
return signer, nil
}
func checkBothCertAndKeyAreSpecified() error {

View File

@ -17,6 +17,7 @@ var ErrRunFailed = errors.New("failed to run controller")
var BootstrapAdminAccountName = "bootstrap-admin"
var address string
var addressSSH string
var debug bool
func newRunCommand() *cobra.Command {
@ -31,7 +32,10 @@ func newRunCommand() *cobra.Command {
port = strconv.FormatInt(netconstants.DefaultControllerPort, 10)
}
cmd.PersistentFlags().StringVarP(&address, "listen", "l", fmt.Sprintf(":%s", port), "address to listen on")
cmd.PersistentFlags().StringVarP(&address, "listen", "l", fmt.Sprintf(":%s", port),
"address to listen on")
cmd.PersistentFlags().StringVar(&addressSSH, "listen-ssh", "",
"address for the built-in SSH server to listen on (e.g. \":6122\")")
cmd.PersistentFlags().BoolVar(&debug, "debug", false, "enable debug logging")
// flags for auto-init if necessary
@ -42,6 +46,8 @@ func newRunCommand() *cobra.Command {
cmd.PersistentFlags().StringVar(&controllerKeyPath, "controller-key", "",
"use the controller certificate key from the specified path instead of the auto-generated one"+
" (requires --controller-cert)")
cmd.PersistentFlags().StringVar(&sshHostKeyPath, "ssh-host-key", "",
"use the SSH private host key from the specified path instead of the auto-generated one")
return cmd
}
@ -74,20 +80,12 @@ func runController(cmd *cobra.Command, args []string) (err error) {
return err
}
var controllerCert tls.Certificate
if dataDir.ControllerCertificateExists() {
controllerCert, err = dataDir.ControllerCertificate()
if err != nil {
return err
}
} else {
controllerCert, err = FindControllerCertificate(dataDir)
if err != nil {
return err
}
controllerCert, err := FindControllerCertificate(dataDir)
if err != nil {
return err
}
controllerInstance, err := controller.New(
controllerOpts := []controller.Option{
controller.WithListenAddr(address),
controller.WithDataDir(dataDir),
controller.WithLogger(logger),
@ -97,7 +95,18 @@ func runController(cmd *cobra.Command, args []string) (err error) {
controllerCert,
},
}),
)
}
if addressSSH != "" {
signer, err := FindSSHHostKey(dataDir)
if err != nil {
return err
}
controllerOpts = append(controllerOpts, controller.WithSSHServer(addressSSH, signer))
}
controllerInstance, err := controller.New(controllerOpts...)
if err != nil {
return err
}

View File

@ -3,6 +3,7 @@ package create
import (
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/simplename"
"github.com/cirruslabs/orchard/pkg/client"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/spf13/cobra"
@ -66,6 +67,11 @@ func newCreateVMCommand() *cobra.Command {
func runCreateVM(cmd *cobra.Command, args []string) error {
name := args[0]
// Issue a warning if the name used will be invalid in the future
if err := simplename.ValidateNext(name); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "WARNING: %v\n", err)
}
// Convert arguments
var hostDirs []v1.HostDir

View File

@ -8,12 +8,14 @@ import (
"github.com/cirruslabs/orchard/internal/controller/notifier"
"github.com/cirruslabs/orchard/internal/controller/proxy"
"github.com/cirruslabs/orchard/internal/controller/scheduler"
"github.com/cirruslabs/orchard/internal/controller/sshserver"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
"github.com/cirruslabs/orchard/internal/controller/store/badger"
"github.com/cirruslabs/orchard/internal/netconstants"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
@ -53,6 +55,10 @@ type Controller struct {
workerOfflineTimeout time.Duration
maxWorkersPerLicense uint
sshListenAddr string
sshSigner ssh.Signer
sshServer *sshserver.SSHServer
rpc.UnimplementedControllerServer
}
@ -94,16 +100,30 @@ func New(opts ...Option) (*Controller, error) {
controller.logger = zap.NewNop().Sugar()
}
// Instantiate controller
// Instantiate the database
store, err := badger.NewBadgerStore(controller.dataDir.DBPath())
if err != nil {
return nil, err
}
controller.store = store
// Instantiate the worker notifier
controller.workerNotifier = notifier.NewNotifier(controller.logger.With("component", "rpc"))
// Instantiate the scheduler
controller.scheduler = scheduler.NewScheduler(store, controller.workerNotifier,
controller.workerOfflineTimeout, controller.logger)
// Instantiate the SSH server (if configured)
if controller.sshListenAddr != "" && controller.sshSigner != nil {
controller.sshServer, err = sshserver.NewSSHServer(controller.sshListenAddr, controller.sshSigner,
store, controller.proxy, controller.workerNotifier, controller.logger)
if err != nil {
return nil, err
}
}
// Instantiate the controller
listener, err := net.Listen("tcp", controller.listenAddr)
if err != nil {
return nil, err
@ -181,6 +201,11 @@ func (controller *Controller) Run(ctx context.Context) error {
// be assigned to a specific Worker
go controller.scheduler.Run()
// Run the SSH server (if configured)
if controller.sshServer != nil {
go controller.sshServer.Run()
}
// A helper function to shut down the HTTP server on context cancellation
go func() {
<-ctx.Done()
@ -206,3 +231,11 @@ func (controller *Controller) Address() string {
return fmt.Sprintf("http://%s", hostPort)
}
func (controller *Controller) SSHAddress() (string, bool) {
if controller.sshServer == nil {
return "", false
}
return controller.sshServer.Address(), true
}

View File

@ -1,25 +1,25 @@
package controller
import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
"os"
"path/filepath"
)
var ErrDataDirError = errors.New("controller's data directory operation error")
type DataDir struct {
path string
}
func NewDataDir(path string) (*DataDir, error) {
if err := os.MkdirAll(path, 0700); err != nil {
return nil, fmt.Errorf("%w: failed to create data directory at path %s: %v",
ErrDataDirError, path, err)
return nil, fmt.Errorf("failed to create data directory at path %s: %w",
path, err)
}
return &DataDir{
@ -30,8 +30,7 @@ func NewDataDir(path string) (*DataDir, error) {
func (dataDir *DataDir) ControllerCertificate() (tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(dataDir.ControllerCertificatePath(), dataDir.ControllerKeyPath())
if err != nil {
return tls.Certificate{}, fmt.Errorf("%w: failed to load controller's certificate and key: %v",
ErrDataDirError, err)
return tls.Certificate{}, fmt.Errorf("failed to load controller's certificate and key: %w", err)
}
return cert, nil
@ -45,8 +44,8 @@ func (dataDir *DataDir) SetControllerCertificate(certificate tls.Certificate) er
privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(certificate.PrivateKey)
if err != nil {
return fmt.Errorf("%w: failed to set controller's certificate: PKCS #8 marshalling failed: %v",
ErrDataDirError, err)
return fmt.Errorf("failed to set controller's certificate: PKCS #8 marshalling failed: %w",
err)
}
privateKeyPEMBytes := pem.EncodeToMemory(&pem.Block{
@ -56,29 +55,38 @@ func (dataDir *DataDir) SetControllerCertificate(certificate tls.Certificate) er
err = os.WriteFile(dataDir.ControllerCertificatePath(), certPEMBytes, 0600)
if err != nil {
return fmt.Errorf("%w: failed to write controller's certificate: %v", ErrDataDirError, err)
return fmt.Errorf("failed to write controller's certificate: %w", err)
}
err = os.WriteFile(dataDir.ControllerKeyPath(), privateKeyPEMBytes, 0600)
if err != nil {
return fmt.Errorf("%w: failed to write controller's key: %v", ErrDataDirError, err)
return fmt.Errorf("failed to write controller's key: %w", err)
}
return nil
}
func (dataDir *DataDir) SSHHostKey() (ssh.Signer, error) {
hostKeyBytes, err := os.ReadFile(dataDir.SSHHostKeyPath())
if err != nil {
return nil, err
}
return ssh.ParsePrivateKey(hostKeyBytes)
}
func (dataDir *DataDir) SetSSHHostKey(privateKey crypto.PrivateKey) error {
pemBlock, err := ssh.MarshalPrivateKey(privateKey, "")
if err != nil {
return err
}
return os.WriteFile(dataDir.SSHHostKeyPath(), pem.EncodeToMemory(pemBlock), 0600)
}
func (dataDir *DataDir) DBPath() string {
return filepath.Join(dataDir.path, "db")
}
func (dataDir *DataDir) ControllerCertificateExists() bool {
return fileExist(dataDir.ControllerCertificatePath()) && fileExist(dataDir.ControllerKeyPath())
}
func fileExist(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func (dataDir *DataDir) ControllerCertificatePath() string {
return filepath.Join(dataDir.path, "controller.crt")
}
@ -87,6 +95,10 @@ func (dataDir *DataDir) ControllerKeyPath() string {
return filepath.Join(dataDir.path, "controller.key")
}
func (dataDir *DataDir) SSHHostKeyPath() string {
return filepath.Join(dataDir.path, "ssh_host_ed25519_key")
}
func (dataDir *DataDir) Initialized() (bool, error) {
dataDirEntries, err := os.ReadDir(dataDir.path)
if err != nil {
@ -94,8 +106,8 @@ func (dataDir *DataDir) Initialized() (bool, error) {
return false, nil
}
return false, fmt.Errorf("%w: failed to read data directory contents at path %s: %v",
ErrDataDirError, dataDir.path, err)
return false, fmt.Errorf("failed to read data directory contents at path %s: %w",
dataDir.path, err)
}
return len(dataDirEntries) != 0, nil

View File

@ -3,6 +3,7 @@ package controller
import (
"crypto/tls"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"time"
)
@ -26,6 +27,13 @@ func WithTLSConfig(tlsConfig *tls.Config) Option {
}
}
func WithSSHServer(listenAddr string, signer ssh.Signer) Option {
return func(controller *Controller) {
controller.sshListenAddr = listenAddr
controller.sshSigner = signer
}
}
func WithInsecureAuthDisabled() Option {
return func(controller *Controller) {
controller.insecureAuthDisabled = true

View File

@ -0,0 +1,284 @@
package sshserver
import (
"context"
"crypto/subtle"
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/controller/notifier"
proxypkg "github.com/cirruslabs/orchard/internal/controller/proxy"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
"github.com/cirruslabs/orchard/internal/proxy"
"github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/google/uuid"
"github.com/samber/lo"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
"net"
"strings"
)
const (
// "ssh -J" uses channels of type "direct-tcpip", which are documented
// in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1].
//
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
channelTypeDirectTCPIP = "direct-tcpip"
)
type SSHServer struct {
listener net.Listener
serverConfig *ssh.ServerConfig
store storepkg.Store
proxy *proxypkg.Proxy
workerNotifier *notifier.Notifier
logger *zap.SugaredLogger
}
func NewSSHServer(
address string,
signer ssh.Signer,
store storepkg.Store,
proxy *proxypkg.Proxy,
workerNotifier *notifier.Notifier,
logger *zap.SugaredLogger,
) (*SSHServer, error) {
server := &SSHServer{
store: store,
proxy: proxy,
workerNotifier: workerNotifier,
logger: logger,
}
listener, err := net.Listen("tcp", address)
if err != nil {
return nil, err
}
server.listener = listener
server.serverConfig = &ssh.ServerConfig{
PasswordCallback: server.passwordCallback,
}
server.serverConfig.AddHostKey(signer)
return server, nil
}
func (server *SSHServer) Run() {
for {
conn, err := server.listener.Accept()
if err != nil {
server.logger.Warnf("failed to accept connection: %v", err)
continue
}
go server.handleConnection(conn)
}
}
func (server *SSHServer) Address() string {
return strings.ReplaceAll(server.listener.Addr().String(), "[::]", "127.0.0.1")
}
func (server *SSHServer) passwordCallback(connMetadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if err := server.store.View(func(txn storepkg.Transaction) error {
// Authenticate
server.logger.Debugf("authenticating user %q using the password authentication",
connMetadata.User())
serviceAccount, err := txn.GetServiceAccount(connMetadata.User())
if err != nil {
if errors.Is(err, storepkg.ErrNotFound) {
return fmt.Errorf("authentication failed, non-existent user %q",
connMetadata.User())
}
server.logger.Errorf("failed to retrieve service account %q: %v",
connMetadata.User(), err)
return fmt.Errorf("authentication failed due to an internal error")
}
if subtle.ConstantTimeCompare([]byte(serviceAccount.Token), password) != 1 {
return fmt.Errorf("authentication failed for user %q: invalid password",
connMetadata.User())
}
// Authorize
if !lo.Contains(serviceAccount.Roles, v1.ServiceAccountRoleComputeWrite) {
return fmt.Errorf("authorization failed for user %q because it lacks %q role",
connMetadata.User(), v1.ServiceAccountRoleComputeWrite)
}
return nil
}); err != nil {
return nil, err
}
return &ssh.Permissions{}, nil
}
func (server *SSHServer) handleConnection(conn net.Conn) {
sshConn, newChannelCh, requestCh, err := ssh.NewServerConn(conn, server.serverConfig)
if err != nil {
server.logger.Warnf("failed to instantiate the SSH server instance to handle "+
"the incoming connection from %s: %v", conn.RemoteAddr().String(), err)
return
}
defer func() {
_ = sshConn.Close()
}()
server.logger.Debugf("accepted SSH connection for user %q connecting from %q",
sshConn.User(), sshConn.RemoteAddr().String())
connCtx, connCtxCancel := context.WithCancel(context.Background())
defer connCtxCancel()
for {
select {
case newChannel, ok := <-newChannelCh:
if !ok {
return
}
switch newChannel.ChannelType() {
case channelTypeDirectTCPIP:
server.logger.Debugf("handling a new direct TCP/IP channel for user %q connecting from %q",
sshConn.User(), sshConn.RemoteAddr().String())
go server.handleDirectTCPIP(connCtx, newChannel)
default:
message := fmt.Sprintf("unsupported channel type requested: %q", newChannel.ChannelType())
server.logger.Debugf(message)
if err := newChannel.Reject(ssh.UnknownChannelType, message); err != nil {
server.logger.Warnf("failed to reject new channel of unsupported type %q: %v",
newChannel.ChannelType(), err)
return
}
}
case request, ok := <-requestCh:
if !ok {
return
}
server.logger.Debugf("refusing to service new request of type %q with payload of %d bytes",
request.Type, len(request.Payload))
if err := request.Reply(false, nil); err != nil {
server.logger.Warnf("failed to reply to a new request of type %q and payload of %d bytes: %v",
request.Type, len(request.Payload), err)
return
}
}
}
}
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.NewChannel) {
// Unmarshal the payload to determine to which VM the user wants to connect to
//
// This direct TCP/IP channel's payload is documented
// in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1].
//
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
payload := struct {
HostToConnect string
PortToConnect uint32
OriginatorIPAddress string
OriginatorPort uint32
}{}
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
message := fmt.Sprintf("failed to unmarshal payload: %v", err)
server.logger.Warn(message)
if err := newChannel.Reject(ssh.ConnectionFailed, message); err != nil {
server.logger.Warnf("failed to reject the new channel: %v", err)
}
return
}
server.logger.Debugf("proxying connection to %s:%d", payload.HostToConnect, payload.PortToConnect)
// Retrieve the VM object
var vm *v1.VM
var err error
err = server.store.View(func(txn storepkg.Transaction) error {
vm, err = txn.GetVM(payload.HostToConnect)
return err
})
if err != nil {
if err := newChannel.Reject(ssh.ConnectionFailed, "failed to find VM"); err != nil {
server.logger.Warnf("failed to reject the new channel due to non-existent VM %q: %v",
payload.HostToConnect, err)
}
return
}
// The user wants to connect to an existing VM, request and wait
// for a connection with the worker before accepting the channel
session := uuid.New().String()
boomerangConnCh, cancel := server.proxy.Request(ctx, session)
defer cancel()
err = server.workerNotifier.Notify(ctx, vm.Worker, &rpc.WatchInstruction{
Action: &rpc.WatchInstruction_PortForwardAction{
PortForwardAction: &rpc.WatchInstruction_PortForward{
Session: session,
VmUid: vm.UID,
Port: payload.PortToConnect,
},
},
})
if err != nil {
server.logger.Warnf("failed to request port-forwarding from the worker %s: %v",
vm.Worker, err)
return
}
// Wait for the connection from worker and commence port forwarding
select {
case fromWorkerConnection := <-boomerangConnCh:
// Now that we have the connection from worker we can accept the channel
acceptedChannel, acceptedChannelRequests, err := newChannel.Accept()
if err != nil {
server.logger.Warnf("failed to accept the new channel: %v", err)
return
}
// Handle new requests on the accepted channel by refusing them
go func() {
req, ok := <-acceptedChannelRequests
if !ok {
return
}
if err := req.Reply(false, nil); err != nil {
server.logger.Warnf("failed to reply to the new channel request: %v", err)
return
}
}()
// Commence port forwarding
if err := proxy.Connections(acceptedChannel, fromWorkerConnection); err != nil {
server.logger.Warnf("failed to port-forward: %v", err)
}
case <-ctx.Done():
return
}
}

View File

@ -2,11 +2,10 @@ package proxy
import (
"io"
"net"
"strings"
)
func Connections(left net.Conn, right net.Conn) (finalErr error) {
func Connections(left io.ReadWriteCloser, right io.ReadWriteCloser) (finalErr error) {
leftErrCh := make(chan error, 1)
rightErrCh := make(chan error, 1)

View File

@ -2,9 +2,14 @@ package simplename
import (
"errors"
"fmt"
)
var ErrNotASimpleName = errors.New("name contains restricted characters, please only use [A-Za-z0-9:-_.]")
var (
ErrNotASimpleName = errors.New("name contains restricted characters, please only use [A-Za-z0-9:-_.]")
ErrNotASimpleNameNext = errors.New("names with characters other than [a-z0-9-] will be deprecated in the future")
)
func Validate(s string) error {
for _, ch := range s {
@ -29,3 +34,53 @@ func Validate(s string) error {
return nil
}
func ValidateNext(s string) error {
// Ensure that the name is not empty
if s == "" {
return fmt.Errorf("name cannot be empty")
}
// Ensure that the name is 63 characters or fewer
if len(s) > 63 {
return fmt.Errorf("names with more than 63 characters " +
"will be depreacted in the future")
}
// Ensure that the name starts and ends with an alphanumeric character
if !isAlphanumeric(s[0]) {
return fmt.Errorf("names not starting with an alphanumeric character " +
"will be deprecated in the future")
}
if !isAlphanumeric(s[len(s)-1]) {
return fmt.Errorf("names not ending with an alphanumeric character " +
"will be deprecated in the future")
}
for i := range s {
if isAlphanumeric(s[i]) {
continue
}
if s[i] == '-' {
continue
}
return ErrNotASimpleNameNext
}
return nil
}
func isAlphanumeric(ch uint8) bool {
if ch >= 'a' && ch <= 'z' {
return true
}
if ch >= '0' && ch <= '9' {
return true
}
return false
}

View File

@ -15,3 +15,23 @@ func TestValidate(t *testing.T) {
require.Error(t, simplename.Validate("vm%"), "special characters")
require.Error(t, simplename.Validate("😐"), "non-ASCII characters")
}
func TestValidateNext(t *testing.T) {
require.NoError(t, simplename.ValidateNext("abcdefghijklmnopqrstuvwxyz-01234567890"))
require.NoError(t, simplename.ValidateNext("vm-1"))
require.NoError(t, simplename.ValidateNext("host-local"))
require.NoError(t, simplename.ValidateNext("x"))
require.Error(t, simplename.ValidateNext("ABCDEFGHIJKLMNOPQRSTUVWXYZ"),
"uppercase characters")
require.Error(t, simplename.ValidateNext(".test"), "does not start with an alphanumeric character")
require.Error(t, simplename.ValidateNext("test."), "does not end with an alphanumeric character")
require.Error(t, simplename.ValidateNext("vm:1"), "special characters")
require.Error(t, simplename.ValidateNext("vm_1"), "special characters")
require.Error(t, simplename.ValidateNext("vm.1"), "special characters")
require.Error(t, simplename.ValidateNext("vm%"), "special characters")
require.Error(t, simplename.ValidateNext("😐"), "non-ASCII characters")
require.Error(t, simplename.ValidateNext(""), "empty name")
require.Error(t, simplename.ValidateNext("1234567890123456789012345678901234567890123456789012345678901234"),
"too long")
}

View File

@ -0,0 +1,57 @@
package devcontroller
import (
"context"
"errors"
"github.com/cirruslabs/orchard/internal/command/dev"
"github.com/cirruslabs/orchard/internal/controller"
"github.com/cirruslabs/orchard/internal/worker"
"github.com/cirruslabs/orchard/pkg/client"
"github.com/stretchr/testify/require"
"net/http"
"testing"
"time"
)
func StartIntegrationTestEnvironment(t *testing.T) (*client.Client, *controller.Controller, *worker.Worker) {
return StartIntegrationTestEnvironmentWithAdditionalOpts(t, nil, nil)
}
func StartIntegrationTestEnvironmentWithAdditionalOpts(
t *testing.T,
additionalControllerOpts []controller.Option,
additionalWorkerOpts []worker.Option,
) (*client.Client, *controller.Controller, *worker.Worker) {
t.Setenv("ORCHARD_HOME", t.TempDir())
devController, devWorker, err := dev.CreateDevControllerAndWorker(t.TempDir(),
":0", nil, additionalControllerOpts, additionalWorkerOpts)
require.NoError(t, err)
t.Cleanup(func() {
_ = devWorker.Close()
})
devContext, cancelDevFunc := context.WithCancel(context.Background())
t.Cleanup(cancelDevFunc)
go func() {
err := devController.Run(devContext)
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
t.Errorf("dev controller failed: %v", err)
}
}()
go func() {
err := devWorker.Run(devContext)
if err != nil && !errors.Is(err, context.Canceled) {
t.Errorf("dev worker failed: %v", err)
}
}()
time.Sleep(5 * time.Second)
devClient, err := client.New(client.WithAddress(devController.Address()))
require.NoError(t, err)
return devClient, devController, devWorker
}

View File

@ -2,14 +2,12 @@ package tests_test
import (
"context"
"errors"
"fmt"
"github.com/cirruslabs/orchard/internal/command/dev"
"github.com/cirruslabs/orchard/internal/controller"
"github.com/cirruslabs/orchard/internal/worker"
"github.com/cirruslabs/orchard/internal/tests/devcontroller"
"github.com/cirruslabs/orchard/internal/tests/wait"
"github.com/cirruslabs/orchard/internal/worker/ondiskname"
"github.com/cirruslabs/orchard/internal/worker/tart"
"github.com/cirruslabs/orchard/pkg/client"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@ -18,7 +16,6 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
@ -28,7 +25,7 @@ import (
)
func TestSingleVM(t *testing.T) {
devClient := StartIntegrationTestEnvironment(t)
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
workers, err := devClient.Workers().List(context.Background())
if err != nil {
@ -52,7 +49,7 @@ func TestSingleVM(t *testing.T) {
if err != nil {
t.Fatal(err)
}
assert.True(t, Wait(2*time.Minute, func() bool {
assert.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err := devClient.VMs().Get(context.Background(), "test-vm")
if err != nil {
t.Fatal(err)
@ -66,7 +63,7 @@ func TestSingleVM(t *testing.T) {
}
assert.Empty(t, runningVM.StatusMessage)
assert.Equal(t, v1.VMStatusRunning, runningVM.Status)
assert.True(t, Wait(2*time.Minute, func() bool {
assert.True(t, wait.Wait(2*time.Minute, func() bool {
logLines, err := devClient.VMs().Logs(context.Background(), "test-vm")
if err != nil {
t.Fatal(err)
@ -92,7 +89,7 @@ func TestSingleVM(t *testing.T) {
require.NoError(t, devClient.VMs().Delete(context.Background(), "test-vm"))
// Ensure that the worker has deleted this VM from disk
assert.True(t, Wait(2*time.Minute, func() bool {
assert.True(t, wait.Wait(2*time.Minute, func() bool {
t.Logf("Waiting for the VM to be garbage collected...")
return !hasVMByPredicate(t, func(info tart.VMInfo) bool {
@ -102,7 +99,7 @@ func TestSingleVM(t *testing.T) {
}
func TestFailedStartupScript(t *testing.T) {
devClient := StartIntegrationTestEnvironment(t)
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
workers, err := devClient.Workers().List(context.Background())
if err != nil {
@ -125,7 +122,7 @@ func TestFailedStartupScript(t *testing.T) {
if err != nil {
t.Fatal(err)
}
assert.True(t, Wait(2*time.Minute, func() bool {
assert.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err := devClient.VMs().Get(context.Background(), "test-vm")
if err != nil {
t.Fatal(err)
@ -141,72 +138,10 @@ func TestFailedStartupScript(t *testing.T) {
"failed to run startup script: Process exited with status 123")
}
func Wait(duration time.Duration, condition func() bool) bool {
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
for {
if condition() {
// all good
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(5 * time.Second):
// try again
continue
}
}
}
func StartIntegrationTestEnvironment(
t *testing.T,
) *client.Client {
return StartIntegrationTestEnvironmentWithAdditionalOpts(t, nil, nil)
}
func StartIntegrationTestEnvironmentWithAdditionalOpts(
t *testing.T,
additionalControllerOpts []controller.Option,
additionalWorkerOpts []worker.Option,
) *client.Client {
t.Setenv("ORCHARD_HOME", t.TempDir())
devController, devWorker, err := dev.CreateDevControllerAndWorker(t.TempDir(),
":0", nil, additionalControllerOpts, additionalWorkerOpts)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = devWorker.Close()
})
devContext, cancelDevFunc := context.WithCancel(context.Background())
t.Cleanup(cancelDevFunc)
go func() {
err := devController.Run(devContext)
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, http.ErrServerClosed) {
t.Errorf("dev controller failed: %v", err)
}
}()
go func() {
err := devWorker.Run(devContext)
if err != nil && !errors.Is(err, context.Canceled) {
t.Errorf("dev worker failed: %v", err)
}
}()
time.Sleep(5 * time.Second)
devClient, err := client.New(client.WithAddress(devController.Address()))
if err != nil {
t.Fatal(err)
}
return devClient
}
func TestPortForwarding(t *testing.T) {
ctx := context.Background()
devClient := StartIntegrationTestEnvironment(t)
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
// Create a generic macOS VM
err := devClient.VMs().Create(ctx, &v1.VM{
@ -261,7 +196,7 @@ func TestPortForwarding(t *testing.T) {
func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
ctx := context.Background()
devClient := StartIntegrationTestEnvironment(t)
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
const (
dummyWorkerName = "dummy-worker"
@ -304,7 +239,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
require.NoError(t, err)
// Wait for the dummy VM to get scheduled to a dummy worker
require.True(t, Wait(2*time.Minute, func() bool {
require.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err := devClient.VMs().Get(context.Background(), dummyVMName)
require.NoError(t, err)
@ -318,7 +253,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
require.NoError(t, err)
// Wait for the scheduler to change the dummy VM's status to "failed"
require.True(t, Wait(2*time.Minute, func() bool {
require.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err := devClient.VMs().Get(context.Background(), dummyVMName)
require.NoError(t, err)
@ -340,7 +275,7 @@ func TestSchedulerHealthCheckingNonExistentWorker(t *testing.T) {
func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) {
ctx := context.Background()
devClient := StartIntegrationTestEnvironmentWithAdditionalOpts(t,
devClient, _, _ := devcontroller.StartIntegrationTestEnvironmentWithAdditionalOpts(t,
[]controller.Option{controller.WithWorkerOfflineTimeout(1 * time.Minute)}, nil)
const (
@ -379,7 +314,7 @@ func TestSchedulerHealthCheckingOfflineWorker(t *testing.T) {
require.NoError(t, err)
// Wait for the VM to be marked as failed
assert.True(t, Wait(2*time.Minute, func() bool {
assert.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err := devClient.VMs().Get(context.Background(), dummyVMName)
require.NoError(t, err)
@ -415,10 +350,10 @@ func TestVMGarbageCollection(t *testing.T) {
require.True(t, hasVM(t, vmName, logger))
// Start the Orchard Worker
_ = StartIntegrationTestEnvironment(t)
_, _, _ = devcontroller.StartIntegrationTestEnvironment(t)
// Wait for the Orchard Worker to garbage-collect this VM
require.True(t, Wait(2*time.Minute, func() bool {
require.True(t, wait.Wait(2*time.Minute, func() bool {
t.Logf("Waiting for the on-disk VM to be cleaned up by the worker")
return !hasVM(t, vmName, logger)
@ -426,7 +361,7 @@ func TestVMGarbageCollection(t *testing.T) {
}
func TestHostDirs(t *testing.T) {
devClient := StartIntegrationTestEnvironment(t)
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
dirToMount := t.TempDir()
@ -461,7 +396,7 @@ func TestHostDirs(t *testing.T) {
var vm *v1.VM
require.True(t, Wait(2*time.Minute, func() bool {
require.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err = devClient.VMs().Get(context.Background(), vmName)
require.NoError(t, err)
@ -475,7 +410,7 @@ func TestHostDirs(t *testing.T) {
var logLines []string
require.True(t, Wait(2*time.Minute, func() bool {
require.True(t, wait.Wait(2*time.Minute, func() bool {
logLines, err = devClient.VMs().Logs(context.Background(), vmName)
require.NoError(t, err)
@ -495,7 +430,7 @@ func TestHostDirs(t *testing.T) {
}
func TestHostDirsInvalidPolicy(t *testing.T) {
devClient := StartIntegrationTestEnvironment(t)
devClient, _, _ := devcontroller.StartIntegrationTestEnvironment(t)
dirToMount := t.TempDir()

View File

@ -0,0 +1,109 @@
package tests_test
import (
"context"
"crypto/subtle"
"fmt"
"github.com/cirruslabs/orchard/internal/controller"
"github.com/cirruslabs/orchard/internal/tests/devcontroller"
"github.com/cirruslabs/orchard/internal/tests/wait"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
"net"
"testing"
"time"
)
func TestSSHServer(t *testing.T) {
// Generate SSH host key for the Controller
publicKey, privateKey, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
sshPublicKey, err := ssh.NewPublicKey(publicKey)
require.NoError(t, err)
signer, err := ssh.NewSignerFromKey(privateKey)
require.NoError(t, err)
// Run the Controller
devClient, devController, _ := devcontroller.StartIntegrationTestEnvironmentWithAdditionalOpts(t, []controller.Option{
controller.WithSSHServer(":0", signer),
}, nil)
// Create a VM to which we'll connect via Controller's SSH server
err = devClient.VMs().Create(context.Background(), &v1.VM{
Meta: v1.Meta{
Name: "test-vm",
},
Image: "ghcr.io/cirruslabs/macos-sonoma-base:latest",
CPU: 4,
Memory: 8 * 1024,
Headless: true,
Status: v1.VMStatusPending,
})
require.NoError(t, err)
// Wait for the VM to start
assert.True(t, wait.Wait(2*time.Minute, func() bool {
vm, err := devClient.VMs().Get(context.Background(), "test-vm")
require.NoError(t, err)
return vm.Status == v1.VMStatusRunning
}), "failed to wait for the VM to start")
// Create a service account whose credentials we'll use to connect to the Controller's SSH server
require.NoError(t, devClient.ServiceAccounts().Create(context.Background(), &v1.ServiceAccount{
Meta: v1.Meta{
Name: "ssh-user",
},
Token: "ssh-password",
Roles: []v1.ServiceAccountRole{
v1.ServiceAccountRoleComputeWrite,
},
}))
// Connect to the VM over Orchard Controller's SSH server
sshAddress, ok := devController.SSHAddress()
require.True(t, ok)
sshClientController, err := ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{
User: "ssh-user",
Auth: []ssh.AuthMethod{
ssh.Password("ssh-password"),
},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if subtle.ConstantTimeCompare(sshPublicKey.Marshal(), key.Marshal()) != 1 {
return fmt.Errorf("untrustred public key was presented by the Controller")
}
return nil
},
})
require.NoError(t, err)
netConnVM, err := sshClientController.Dial("tcp", "test-vm:22")
require.NoError(t, err)
sshConnVM, sshChansVM, sshReqsVM, err := ssh.NewClientConn(netConnVM, "test-vm:22", &ssh.ClientConfig{
User: "admin",
Auth: []ssh.AuthMethod{
ssh.Password("admin"),
},
HostKeyCallback: func(_ string, _ net.Addr, _ ssh.PublicKey) error {
return nil
},
})
require.NoError(t, err)
sshClientVM := ssh.NewClient(sshConnVM, sshChansVM, sshReqsVM)
sshSessVM, err := sshClientVM.NewSession()
require.NoError(t, err)
unameBytes, err := sshSessVM.Output("uname -a")
require.NoError(t, err)
require.Contains(t, string(unameBytes), "Darwin")
}

View File

@ -0,0 +1,24 @@
package wait
import (
"context"
"time"
)
func Wait(duration time.Duration, predicate func() bool) bool {
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
for {
if predicate() {
return true
}
select {
case <-ctx.Done():
return false
case <-time.After(1 * time.Second):
continue
}
}
}