Add VM-scoped temporary JWT access tokens

This commit is contained in:
Fedor Korotkov 2026-02-05 22:58:18 +01:00
parent 448da2a83b
commit 310ff200ea
17 changed files with 1161 additions and 21 deletions

View File

@ -447,6 +447,32 @@ paths:
description: VM resource with the given name doesn't exist
'503':
description: Failed to resolve the IP address on the worker responsible for the specified VM
/vms/{name}/access-tokens:
parameters:
- in: path
name: name
required: true
schema:
type: string
post:
summary: "Issue a temporary VM access token"
tags:
- vms
requestBody:
required: false
content:
application/json:
schema:
$ref: '#/components/schemas/IssueVMAccessTokenRequest'
responses:
'200':
description: VM access token was successfully issued
content:
application/json:
schema:
$ref: '#/components/schemas/VMAccessToken'
'404':
description: VM resource with the given name doesn't exist
components:
schemas:
Worker:
@ -693,6 +719,41 @@ components:
ip:
type: string
description: The resolved IP address
IssueVMAccessTokenRequest:
title: Issue VM Access Token Request
type: object
properties:
ttlSeconds:
type: integer
format: uint64
description: |
Requested token lifetime in seconds. Defaults to 86400 (24h) when omitted.
Maximum allowed value is 2592000 (30d).
VMAccessToken:
title: VM Access Token
type: object
properties:
token:
type: string
description: Signed temporary JWT token
tokenType:
type: string
description: Token type to use in Authorization header
example: Bearer
expiresAt:
type: string
format: date-time
description: RFC3339 timestamp at which the token expires
sshUsername:
type: string
description: Username to use when authenticating to the built-in SSH server
example: token
vmName:
type: string
description: VM name this token was issued for
vmUID:
type: string
description: Immutable VM UID this token is bound to
Event:
title: Generic Resource Event
type: object

View File

@ -17,6 +17,7 @@ func NewCommand() *cobra.Command {
newGetBootstrapTokenCommand(),
newGetClusterSettingsCommand(),
newGetServiceAccountCommand(),
newGetVMAccessTokenCommand(),
newGetVMCommand(),
newGetWorkerCommand(),
)

View File

@ -0,0 +1,58 @@
package get
import (
"fmt"
"time"
"github.com/cirruslabs/orchard/internal/vmtempauth"
"github.com/cirruslabs/orchard/pkg/client"
"github.com/spf13/cobra"
)
var vmAccessTokenTTL time.Duration
// newGetVMAccessTokenCommand builds the "vm-access-token" sub-command. It
// takes exactly one positional argument (the VM name) and a --ttl duration
// flag whose value is stored in vmAccessTokenTTL.
func newGetVMAccessTokenCommand() *cobra.Command {
	ttlUsage := fmt.Sprintf("token TTL (default: %s, max: %s)", vmtempauth.DefaultTTL, vmtempauth.MaxTTL)

	cmd := &cobra.Command{
		Use:   "vm-access-token VM_NAME",
		Short: "Issue a temporary VM access token",
		Args:  cobra.ExactArgs(1),
		RunE:  runGetVMAccessToken,
	}
	cmd.Flags().DurationVar(&vmAccessTokenTTL, "ttl", vmtempauth.DefaultTTL, ttlUsage)

	return cmd
}
// runGetVMAccessToken validates the --ttl flag, asks the controller to issue
// an access token for the VM named by the first argument and prints the
// resulting token to standard output.
func runGetVMAccessToken(cmd *cobra.Command, args []string) error {
	vmName := args[0]

	switch {
	case vmAccessTokenTTL <= 0:
		return fmt.Errorf("%w: --ttl must be greater than 0", ErrGetFailed)
	case vmAccessTokenTTL > vmtempauth.MaxTTL:
		return fmt.Errorf("%w: --ttl cannot exceed %s", ErrGetFailed, vmtempauth.MaxTTL)
	}

	// The request carries whole seconds, so a positive sub-second TTL
	// truncates to zero and is rejected here.
	seconds := uint64(vmAccessTokenTTL / time.Second)
	if seconds == 0 {
		return fmt.Errorf("%w: --ttl is too small", ErrGetFailed)
	}

	apiClient, err := client.New()
	if err != nil {
		return err
	}

	opts := client.IssueAccessTokenOptions{TTLSeconds: &seconds}

	issued, err := apiClient.VMs().IssueAccessToken(cmd.Context(), vmName, opts)
	if err != nil {
		return err
	}

	fmt.Println(issued.Token)

	return nil
}

View File

@ -7,10 +7,12 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/cirruslabs/orchard/api"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
"github.com/cirruslabs/orchard/internal/responder"
"github.com/cirruslabs/orchard/internal/vmtempauth"
v1pkg "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/deckarep/golang-set/v2"
@ -24,6 +26,7 @@ import (
)
const ctxServiceAccountKey = "service-account"
const ctxVMAccessTokenClaimsKey = "vm-access-token-claims"
var ErrUnauthorized = errors.New("unauthorized")
@ -183,6 +186,9 @@ func (controller *Controller) initAPI() *gin.Engine {
v1.POST("/vms/:name/events", func(c *gin.Context) {
controller.appendVMEvents(c).Respond(c)
})
v1.POST("/vms/:name/access-tokens", func(c *gin.Context) {
controller.issueVMAccessToken(c).Respond(c)
})
return ginEngine
}
@ -215,6 +221,29 @@ func (controller *Controller) fetchServiceAccount(name string, token string) (*v
}
func (controller *Controller) authenticateMiddleware(c *gin.Context) {
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if len(authHeader) >= len("Bearer ") && strings.EqualFold(authHeader[:len("Bearer ")], "Bearer ") {
token := strings.TrimSpace(authHeader[len("Bearer "):])
if token == "" {
responder.Code(http.StatusUnauthorized).Respond(c)
return
}
claims, err := vmtempauth.Verify(controller.vmAccessTokenKey, token, time.Now().UTC())
if err != nil {
responder.Code(http.StatusUnauthorized).Respond(c)
return
}
c.Set(ctxVMAccessTokenClaimsKey, claims)
c.Next()
return
}
// Retrieve presented credentials (if any)
user, password, ok := c.Request.BasicAuth()
if !ok {
@ -240,6 +269,34 @@ func (controller *Controller) authenticateMiddleware(c *gin.Context) {
c.Next()
}
// serviceAccountFromContext returns the service account stored in the request
// context under ctxServiceAccountKey. The boolean is false when nothing is
// stored under that key or the stored value is not a *v1pkg.ServiceAccount.
func (controller *Controller) serviceAccountFromContext(ctx *gin.Context) (*v1pkg.ServiceAccount, bool) {
	if stored, found := ctx.Get(ctxServiceAccountKey); found {
		if account, isAccount := stored.(*v1pkg.ServiceAccount); isAccount {
			return account, true
		}
	}

	return nil, false
}
// vmAccessTokenClaimsFromContext returns the VM access token claims stored in
// the request context under ctxVMAccessTokenClaimsKey. The boolean is false
// when nothing is stored under that key or the stored value is not a
// *vmtempauth.Claims.
func (controller *Controller) vmAccessTokenClaimsFromContext(ctx *gin.Context) (*vmtempauth.Claims, bool) {
	if stored, found := ctx.Get(ctxVMAccessTokenClaimsKey); found {
		if claims, isClaims := stored.(*vmtempauth.Claims); isClaims {
			return claims, true
		}
	}

	return nil, false
}
type AuthorizeMode int
const (
@ -270,11 +327,10 @@ func (controller *Controller) authorizeBase(
return nil
}
serviceAccountUntyped, ok := ctx.Get(ctxServiceAccountKey)
serviceAccount, ok := controller.serviceAccountFromContext(ctx)
if !ok {
return responder.Code(http.StatusUnauthorized)
}
serviceAccount := serviceAccountUntyped.(*v1pkg.ServiceAccount)
serviceAccountRolesSet := mapset.NewSet[v1pkg.ServiceAccountRole](serviceAccount.Roles...)
var authorized bool

View File

@ -0,0 +1,73 @@
package controller
import (
"errors"
"io"
"net/http"
"time"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
"github.com/cirruslabs/orchard/internal/responder"
"github.com/cirruslabs/orchard/internal/vmtempauth"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/gin-gonic/gin"
)
// issueVMAccessToken issues a temporary access token bound to the VM named by
// the "name" path parameter.
//
// The caller must hold the compute-write role. The JSON body is optional and
// may carry a TTL in seconds. Responses:
//   - 400 when the body is present but is not valid JSON
//   - 412 when the requested TTL is rejected by vmtempauth.NormalizeTTL
//   - whatever responder.Error produces when the VM lookup fails
//   - 401 when no service account is attached to the request context
//   - 500 when signing the token fails
//   - 200 with a v1.VMAccessToken on success
func (controller *Controller) issueVMAccessToken(ctx *gin.Context) responder.Responder {
	if denied := controller.authorize(ctx, v1.ServiceAccountRoleComputeWrite); denied != nil {
		return denied
	}

	// io.EOF from the binder is tolerated so that the body may be omitted.
	var request v1.IssueVMAccessTokenRequest
	bindErr := ctx.ShouldBindJSON(&request)
	if bindErr != nil && !errors.Is(bindErr, io.EOF) {
		return responder.JSON(http.StatusBadRequest, NewErrorResponse("invalid JSON was provided"))
	}

	ttl, err := vmtempauth.NormalizeTTL(request.TTLSeconds)
	if err != nil {
		return responder.JSON(http.StatusPreconditionFailed, NewErrorResponse("%v", err))
	}

	vmName := ctx.Param("name")

	var vm *v1.VM
	lookupFailure := controller.storeView(func(txn storepkg.Transaction) responder.Responder {
		var lookupErr error
		if vm, lookupErr = txn.GetVM(vmName); lookupErr != nil {
			return responder.Error(lookupErr)
		}

		return nil
	})
	if lookupFailure != nil {
		return lookupFailure
	}

	// The issuing service account's name becomes the token subject.
	issuer, found := controller.serviceAccountFromContext(ctx)
	if !found {
		return responder.Code(http.StatusUnauthorized)
	}

	issued, err := vmtempauth.Issue(controller.vmAccessTokenKey, vmtempauth.IssueInput{
		Issuer:  vmtempauth.AccessTokenIssuer,
		Subject: issuer.Name,
		VMUID:   vm.UID,
		VMName:  vm.Name,
		TTL:     ttl,
		Now:     time.Now().UTC(),
	})
	if err != nil {
		controller.logger.Errorf("failed to issue VM access token: %v", err)

		return responder.Code(http.StatusInternalServerError)
	}

	payload := &v1.VMAccessToken{
		Token:       issued.Token,
		TokenType:   "Bearer",
		ExpiresAt:   issued.ExpiresAt,
		SSHUsername: vmtempauth.SSHUsername,
		VMName:      vm.Name,
		VMUID:       vm.UID,
	}

	return responder.JSON(http.StatusOK, payload)
}

View File

@ -8,6 +8,7 @@ import (
"time"
"github.com/cirruslabs/orchard/internal/responder"
"github.com/cirruslabs/orchard/internal/vmtempauth"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/gin-gonic/gin"
@ -15,9 +16,15 @@ import (
)
func (controller *Controller) ip(ctx *gin.Context) responder.Responder {
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
v1.ServiceAccountRoleComputeConnect); responder != nil {
return responder
vmAccessTokenClaims, vmAccessTokenAuth := controller.vmAccessTokenClaimsFromContext(ctx)
if !vmAccessTokenAuth {
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
v1.ServiceAccountRoleComputeConnect); responder != nil {
return responder
}
} else if !vmAccessTokenClaims.HasScope(vmtempauth.ScopeVMIP) {
return responder.Code(http.StatusUnauthorized)
}
// Retrieve and parse path and query parameters
@ -36,6 +43,9 @@ func (controller *Controller) ip(ctx *gin.Context) responder.Responder {
if responderImpl != nil {
return responderImpl
}
if vmAccessTokenAuth && !vmAccessTokenClaims.CanAccessVM(vm.UID) {
return responder.JSON(http.StatusForbidden, NewErrorResponse("the VM access token does not allow access to this VM"))
}
// Send an IP resolution request and wait for the result
session := uuid.New().String()

View File

@ -11,6 +11,7 @@ import (
"github.com/cirruslabs/orchard/internal/netconncancel"
"github.com/cirruslabs/orchard/internal/proxy"
"github.com/cirruslabs/orchard/internal/responder"
"github.com/cirruslabs/orchard/internal/vmtempauth"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/coder/websocket"
@ -22,9 +23,15 @@ import (
)
func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responder {
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
v1.ServiceAccountRoleComputeConnect); responder != nil {
return responder
vmAccessTokenClaims, vmAccessTokenAuth := controller.vmAccessTokenClaimsFromContext(ctx)
if !vmAccessTokenAuth {
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
v1.ServiceAccountRoleComputeConnect); responder != nil {
return responder
}
} else if !vmAccessTokenClaims.HasScope(vmtempauth.ScopeVMPortForward) {
return responder.Code(http.StatusUnauthorized)
}
// Retrieve and parse path and query parameters
@ -52,6 +59,9 @@ func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responde
if responderImpl != nil {
return responderImpl
}
if vmAccessTokenAuth && !vmAccessTokenClaims.CanAccessVM(vm.UID) {
return responder.JSON(http.StatusForbidden, NewErrorResponse("the VM access token does not allow access to this VM"))
}
// Commence port forwarding
return controller.portForward(ctx, waitContext, vm.Worker, vm.UID, uint32(port))

View File

@ -20,6 +20,7 @@ import (
"github.com/cirruslabs/orchard/internal/controller/store/badger"
"github.com/cirruslabs/orchard/internal/netconstants"
"github.com/cirruslabs/orchard/internal/opentelemetry"
"github.com/cirruslabs/orchard/internal/vmtempauth"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/samber/lo"
@ -67,6 +68,7 @@ type Controller struct {
pingInterval time.Duration
prometheusMetrics bool
synthetic bool
vmAccessTokenKey []byte
sshListenAddr string
sshSigner ssh.Signer
@ -115,6 +117,10 @@ func New(opts ...Option) (*Controller, error) {
if controller.logger == nil {
controller.logger = zap.NewNop().Sugar()
}
var err error
if controller.vmAccessTokenKey, err = controller.loadOrCreateVMAccessTokenKey(); err != nil {
return nil, fmt.Errorf("%w: failed to initialize VM access token key: %v", ErrInitFailed, err)
}
// Instantiate the database
store, err := badger.NewBadgerStore(controller.dataDir.DBPath(), controller.disableDBCompression,
@ -137,7 +143,8 @@ func New(opts ...Option) (*Controller, error) {
// Instantiate the SSH server (if configured)
if controller.sshListenAddr != "" && controller.sshSigner != nil {
controller.sshServer, err = sshserver.NewSSHServer(controller.sshListenAddr, controller.sshSigner,
store, controller.connRendezvous, controller.workerNotifier, controller.sshNoClientAuth, controller.logger)
store, controller.connRendezvous, controller.workerNotifier, controller.vmAccessTokenKey,
controller.sshNoClientAuth, controller.logger)
if err != nil {
return nil, err
}
@ -197,6 +204,32 @@ func New(opts ...Option) (*Controller, error) {
return controller, nil
}
// loadOrCreateVMAccessTokenKey returns the VM access token signing key kept in
// the controller's data directory.
//
// An existing key is returned as-is after its length is checked against
// vmtempauth.SigningKeySizeBytes. When the key file does not exist, a new key
// is generated, written to the data directory and returned. Any other read
// error is returned unchanged.
func (controller *Controller) loadOrCreateVMAccessTokenKey() ([]byte, error) {
	existing, readErr := controller.dataDir.VMAccessTokenSigningKey()

	switch {
	case readErr == nil:
		if got := len(existing); got != vmtempauth.SigningKeySizeBytes {
			return nil, fmt.Errorf("%w: expected %d bytes, got %d",
				vmtempauth.ErrInvalidSigningKey, vmtempauth.SigningKeySizeBytes, got)
		}

		return existing, nil
	case !errors.Is(readErr, os.ErrNotExist):
		return nil, readErr
	}

	// No key on disk yet: generate one and persist it.
	fresh, err := vmtempauth.NewSigningKey()
	if err != nil {
		return nil, err
	}

	if err := controller.dataDir.SetVMAccessTokenSigningKey(fresh); err != nil {
		return nil, err
	}

	return fresh, nil
}
func (controller *Controller) ServiceAccounts() ([]v1.ServiceAccount, error) {
var serviceAccounts []v1.ServiceAccount
var err error

View File

@ -99,6 +99,18 @@ func (dataDir *DataDir) SSHHostKeyPath() string {
return filepath.Join(dataDir.path, "ssh_host_ed25519_key")
}
// VMAccessTokenSigningKeyPath returns the location of the VM access token
// signing key file inside the data directory.
func (dataDir *DataDir) VMAccessTokenSigningKeyPath() string {
	const fileName = "vm_access_token_signing_key"

	return filepath.Join(dataDir.path, fileName)
}
// VMAccessTokenSigningKey reads and returns the raw contents of the VM access
// token signing key file. The error from os.ReadFile is returned unchanged.
func (dataDir *DataDir) VMAccessTokenSigningKey() ([]byte, error) {
	keyPath := dataDir.VMAccessTokenSigningKeyPath()

	return os.ReadFile(keyPath)
}
// SetVMAccessTokenSigningKey writes signingKey to the VM access token signing
// key file, creating the file with mode 0600 when it does not exist yet.
func (dataDir *DataDir) SetVMAccessTokenSigningKey(signingKey []byte) error {
	const ownerReadWrite os.FileMode = 0600

	return os.WriteFile(dataDir.VMAccessTokenSigningKeyPath(), signingKey, ownerReadWrite)
}
func (dataDir *DataDir) Initialized() (bool, error) {
dataDirEntries, err := os.ReadDir(dataDir.path)
if err != nil {

View File

@ -13,6 +13,7 @@ import (
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
"github.com/cirruslabs/orchard/internal/proxy"
"github.com/cirruslabs/orchard/internal/vmtempauth"
"github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/cirruslabs/orchard/rpc"
"github.com/google/uuid"
@ -27,15 +28,21 @@ const (
//
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
channelTypeDirectTCPIP = "direct-tcpip"
permissionPrincipalTypeExt = "orchard-principal-type"
permissionVMUIDExt = "orchard-vm-uid"
principalTypeVMAccessToken = "vm-access-token"
)
type SSHServer struct {
listener net.Listener
serverConfig *ssh.ServerConfig
store storepkg.Store
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]]
workerNotifier *notifier.Notifier
logger *zap.SugaredLogger
listener net.Listener
serverConfig *ssh.ServerConfig
store storepkg.Store
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]]
workerNotifier *notifier.Notifier
vmAccessTokenKey []byte
logger *zap.SugaredLogger
}
func NewSSHServer(
@ -44,14 +51,16 @@ func NewSSHServer(
store storepkg.Store,
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]],
workerNotifier *notifier.Notifier,
vmAccessTokenKey []byte,
noClientAuth bool,
logger *zap.SugaredLogger,
) (*SSHServer, error) {
server := &SSHServer{
store: store,
connRendezvous: connRendezvous,
workerNotifier: workerNotifier,
logger: logger,
store: store,
connRendezvous: connRendezvous,
workerNotifier: workerNotifier,
vmAccessTokenKey: vmAccessTokenKey,
logger: logger,
}
listener, err := net.Listen("tcp", address)
@ -87,6 +96,24 @@ func (server *SSHServer) Address() string {
}
func (server *SSHServer) passwordCallback(connMetadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if connMetadata.User() == vmtempauth.SSHUsername {
claims, err := vmtempauth.Verify(server.vmAccessTokenKey, string(password), time.Now().UTC())
if err != nil {
return nil, fmt.Errorf("authentication failed for user %q: invalid VM access token", connMetadata.User())
}
if !claims.HasScope(vmtempauth.ScopeVMSSHJumpbox) {
return nil, fmt.Errorf("authorization failed for user %q: token lacks %q scope",
connMetadata.User(), vmtempauth.ScopeVMSSHJumpbox)
}
return &ssh.Permissions{
Extensions: map[string]string{
permissionPrincipalTypeExt: principalTypeVMAccessToken,
permissionVMUIDExt: claims.VMUID,
},
}, nil
}
if err := server.store.View(func(txn storepkg.Transaction) error {
// Authenticate
server.logger.Debugf("authenticating user %q using the password authentication",
@ -157,7 +184,7 @@ func (server *SSHServer) handleConnection(conn net.Conn) {
server.logger.Debugf("handling a new direct TCP/IP channel for user %q connecting from %q",
sshConn.User(), sshConn.RemoteAddr().String())
go server.handleDirectTCPIP(connCtx, newChannel)
go server.handleDirectTCPIP(connCtx, sshConn, newChannel)
default:
message := fmt.Sprintf("unsupported channel type requested: %q", newChannel.ChannelType())
@ -188,7 +215,7 @@ func (server *SSHServer) handleConnection(conn net.Conn) {
}
}
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.NewChannel) {
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, sshConn *ssh.ServerConn, newChannel ssh.NewChannel) {
// Unmarshal the payload to determine to which VM the user wants to connect to
//
// This direct TCP/IP channel's payload is documented
@ -234,6 +261,20 @@ func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.N
return
}
if sshConn.Permissions != nil &&
sshConn.Permissions.Extensions[permissionPrincipalTypeExt] == principalTypeVMAccessToken {
tokenVMUID := sshConn.Permissions.Extensions[permissionVMUIDExt]
if tokenVMUID == "" || tokenVMUID != vm.UID {
message := "authorization failed for requested VM"
if err := newChannel.Reject(ssh.Prohibited, message); err != nil {
server.logger.Warnf("failed to reject the new channel due to VM access token mismatch: %v", err)
}
return
}
}
// The user wants to connect to an existing VM, request and wait
// for a connection with the worker before accepting the channel
session := uuid.New().String()

View File

@ -0,0 +1,337 @@
package tests_test
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"testing"
"time"
"github.com/cirruslabs/orchard/internal/controller"
"github.com/cirruslabs/orchard/internal/echoserver"
"github.com/cirruslabs/orchard/internal/imageconstant"
"github.com/cirruslabs/orchard/internal/tests/wait"
"github.com/cirruslabs/orchard/internal/vmtempauth"
workerpkg "github.com/cirruslabs/orchard/internal/worker"
"github.com/cirruslabs/orchard/pkg/client"
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
)
// authenticatedTestEnvironment bundles the pieces created by
// startAuthenticatedTestEnvironment that the tests interact with.
type authenticatedTestEnvironment struct {
// adminClient is an API client authenticated as the "admin" service account.
adminClient *client.Client
// controller is the controller instance under test.
controller *controller.Controller
// worker is the synthetic worker registered with the controller.
worker *workerpkg.Worker
}
// TestVMAccessTokenAPI checks that a token issued for one VM can be used as a
// bearer credential for the IP and port-forwarding endpoints of that VM, and
// that it is rejected for another VM and for an unrelated endpoint.
func TestVMAccessTokenAPI(t *testing.T) {
env := startAuthenticatedTestEnvironment(t, false)
defer func() {
_ = env.worker.Close()
}()
vmA := createRunningSyntheticVM(t, env.adminClient, "test-vm-a")
vmB := createRunningSyntheticVM(t, env.adminClient, "test-vm-b")
// Start a local echo server to act as the port-forwarding target.
echoServer, err := echoserver.New()
require.NoError(t, err)
echoContext, echoCancel := context.WithCancel(context.Background())
defer echoCancel()
go func() {
_ = echoServer.Run(echoContext)
}()
echoPort := parsePort(t, echoServer.Addr())
// Issue a token for VM A with the default TTL and build a client that
// authenticates with that token only.
tokenA := issueAccessToken(t, env.adminClient, vmA.Name, nil)
tokenClient, err := client.New(client.WithAddress(env.controller.Address()), client.WithBearerToken(tokenA.Token))
require.NoError(t, err)
// The token grants IP resolution for VM A.
ip, err := tokenClient.VMs().IP(context.Background(), vmA.Name, 30)
require.NoError(t, err)
require.NotEmpty(t, ip)
// The token grants port forwarding to VM A; round-trip data through the echo server.
vmAConn, err := tokenClient.VMs().PortForward(context.Background(), vmA.Name, echoPort, 30)
require.NoError(t, err)
t.Cleanup(func() {
_ = vmAConn.Close()
})
require.NoError(t, vmAConn.SetDeadline(time.Now().Add(10*time.Second)))
_, err = vmAConn.Write([]byte("hello"))
require.NoError(t, err)
result := make([]byte, len("hello"))
_, err = io.ReadFull(vmAConn, result)
require.NoError(t, err)
require.Equal(t, "hello", string(result))
// The same token used against VM B must yield an API error mentioning 403.
_, err = tokenClient.VMs().IP(context.Background(), vmB.Name, 30)
require.Error(t, err)
require.ErrorIs(t, err, client.ErrAPI)
require.Contains(t, err.Error(), "403")
// Listing VMs with the token must yield an API error mentioning 401.
_, err = tokenClient.VMs().List(context.Background())
require.Error(t, err)
require.ErrorIs(t, err, client.ErrAPI)
require.Contains(t, err.Error(), "401")
}
// TestVMAccessTokenVMRecreation checks that a token issued for a VM stops
// working once that VM is deleted and another VM is created under the same
// name: the IP request must then fail with an API error mentioning 403.
func TestVMAccessTokenVMRecreation(t *testing.T) {
env := startAuthenticatedTestEnvironment(t, false)
defer func() {
_ = env.worker.Close()
}()
vm := createRunningSyntheticVM(t, env.adminClient, "test-vm")
token := issueAccessToken(t, env.adminClient, vm.Name, nil)
tokenClient, err := client.New(client.WithAddress(env.controller.Address()), client.WithBearerToken(token.Token))
require.NoError(t, err)
// Replace the VM with a new one that reuses the original name.
require.NoError(t, env.adminClient.VMs().Delete(context.Background(), vm.Name))
createRunningSyntheticVM(t, env.adminClient, vm.Name)
// The token issued for the original VM must be rejected for the new one.
_, err = tokenClient.VMs().IP(context.Background(), vm.Name, 30)
require.Error(t, err)
require.ErrorIs(t, err, client.ErrAPI)
require.Contains(t, err.Error(), "403")
}
// TestVMAccessTokenSSHServer checks VM access tokens against the controller's
// built-in SSH server: a token used as the SSH password lets the client tunnel
// to the VM it was issued for, tunnelling to another VM fails, and an expired
// token cannot be used to log in.
func TestVMAccessTokenSSHServer(t *testing.T) {
env := startAuthenticatedTestEnvironment(t, true)
defer func() {
_ = env.worker.Close()
}()
vmA := createRunningSyntheticVM(t, env.adminClient, "test-vm-a")
vmB := createRunningSyntheticVM(t, env.adminClient, "test-vm-b")
// Start a local echo server to act as the tunnel target.
echoServer, err := echoserver.New()
require.NoError(t, err)
echoContext, echoCancel := context.WithCancel(context.Background())
defer echoCancel()
go func() {
_ = echoServer.Run(echoContext)
}()
echoPort := parsePort(t, echoServer.Addr())
tokenA := issueAccessToken(t, env.adminClient, vmA.Name, nil)
sshAddress, ok := env.controller.SSHAddress()
require.True(t, ok)
// Log in with the dedicated token username and the token as the password.
sshClientController, err := ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{
User: vmtempauth.SSHUsername,
Auth: []ssh.AuthMethod{
ssh.Password(tokenA.Token),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 10 * time.Second,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = sshClientController.Close()
})
// Tunnelling to VM A works; round-trip data through the echo server.
vmAConn, err := sshClientController.Dial("tcp", fmt.Sprintf("%s:%d", vmA.Name, echoPort))
require.NoError(t, err)
t.Cleanup(func() {
_ = vmAConn.Close()
})
_, err = vmAConn.Write([]byte("jumpbox"))
require.NoError(t, err)
vmAResult := make([]byte, len("jumpbox"))
_, err = io.ReadFull(vmAConn, vmAResult)
require.NoError(t, err)
require.Equal(t, "jumpbox", string(vmAResult))
// Tunnelling to VM B over the same SSH connection must fail.
_, err = sshClientController.Dial("tcp", fmt.Sprintf("%s:%d", vmB.Name, echoPort))
require.Error(t, err)
// Issue a token with a one-second TTL, wait for it to expire, then
// confirm it can no longer be used to log in.
shortTTL := uint64(1)
shortToken := issueAccessToken(t, env.adminClient, vmA.Name, &shortTTL)
time.Sleep(2 * time.Second)
_, err = ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{
User: vmtempauth.SSHUsername,
Auth: []ssh.AuthMethod{
ssh.Password(shortToken.Token),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 10 * time.Second,
})
require.Error(t, err)
}
// startAuthenticatedTestEnvironment starts a controller backed by a temporary
// data directory together with one synthetic worker, and waits until the
// worker shows up in the controller's worker list.
//
// Two service accounts are created: "admin" with every role and "worker" with
// the compute read and write roles. When withSSHServer is true the controller
// is also configured with an SSH server that uses a freshly generated ed25519
// host key. The controller and the worker run until the test's cleanup
// cancels their shared context.
func startAuthenticatedTestEnvironment(
t *testing.T,
withSSHServer bool,
) *authenticatedTestEnvironment {
t.Helper()
dataDir, err := controller.NewDataDir(t.TempDir())
require.NoError(t, err)
controllerOpts := []controller.Option{
controller.WithDataDir(dataDir),
controller.WithListenAddr(":0"),
controller.WithExperimentalRPCV2(),
}
if withSSHServer {
_, privateKey, err := ed25519.GenerateKey(nil)
require.NoError(t, err)
signer, err := ssh.NewSignerFromKey(privateKey)
require.NoError(t, err)
controllerOpts = append(controllerOpts, controller.WithSSHServer(":0", signer, false))
}
controllerInstance, err := controller.New(controllerOpts...)
require.NoError(t, err)
// Service account used by the tests themselves.
require.NoError(t, controllerInstance.EnsureServiceAccount(&v1.ServiceAccount{
Meta: v1.Meta{
Name: "admin",
},
Token: "admin-token",
Roles: v1.AllServiceAccountRoles(),
}))
// Service account used by the synthetic worker.
require.NoError(t, controllerInstance.EnsureServiceAccount(&v1.ServiceAccount{
Meta: v1.Meta{
Name: "worker",
},
Token: "worker-token",
Roles: []v1.ServiceAccountRole{
v1.ServiceAccountRoleComputeRead,
v1.ServiceAccountRoleComputeWrite,
},
}))
adminClient, err := client.New(
client.WithAddress(controllerInstance.Address()),
client.WithCredentials("admin", "admin-token"),
)
require.NoError(t, err)
workerClient, err := client.New(
client.WithAddress(controllerInstance.Address()),
client.WithCredentials("worker", "worker-token"),
)
require.NoError(t, err)
workerInstance, err := workerpkg.New(workerClient, workerpkg.WithName("worker-a"), workerpkg.WithSynthetic())
require.NoError(t, err)
// Both background goroutines stop when the test finishes and the context is cancelled.
testContext, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
go func() {
runErr := controllerInstance.Run(testContext)
// Cancellation and a closed HTTP server are not reported as failures.
if runErr != nil && !errors.Is(runErr, context.Canceled) && !errors.Is(runErr, http.ErrServerClosed) {
t.Errorf("controller failed: %v", runErr)
}
}()
go func() {
runErr := workerInstance.Run(testContext)
if runErr != nil && !errors.Is(runErr, context.Canceled) {
t.Errorf("worker failed: %v", runErr)
}
}()
// Wait for up to 30 seconds until exactly one worker is listed.
assert.True(t, wait.Wait(30*time.Second, func() bool {
workers, err := adminClient.Workers().List(context.Background())
if err != nil {
return false
}
return len(workers) == 1
}), "failed to wait for worker to register")
return &authenticatedTestEnvironment{
adminClient: adminClient,
controller: controllerInstance,
worker: workerInstance,
}
}
// createRunningSyntheticVM creates a VM with the given name through apiClient,
// waits for up to 60 seconds until it reports the running status, and returns
// the VM as last fetched from the API. Any failure aborts the test.
func createRunningSyntheticVM(t *testing.T, apiClient *client.Client, name string) *v1.VM {
	t.Helper()

	spec := &v1.VM{
		Meta: v1.Meta{
			Name: name,
		},
		Image:    imageconstant.DefaultMacosImage,
		CPU:      1,
		Memory:   512,
		Headless: true,
		Status:   v1.VMStatusPending,
	}
	require.NoError(t, apiClient.VMs().Create(context.Background(), spec))

	isRunning := func() bool {
		current, getErr := apiClient.VMs().Get(context.Background(), name)

		return getErr == nil && current.Status == v1.VMStatusRunning
	}
	require.Truef(t, wait.Wait(60*time.Second, isRunning),
		"failed to wait for VM %q to reach running state", name)

	vm, err := apiClient.VMs().Get(context.Background(), name)
	require.NoError(t, err)

	return vm
}
// issueAccessToken requests an access token for vmName through apiClient and
// returns it. A nil ttlSeconds leaves the TTL unset in the request. The test
// is aborted when the request fails or the returned token is empty.
func issueAccessToken(
	t *testing.T,
	apiClient *client.Client,
	vmName string,
	ttlSeconds *uint64,
) *v1.VMAccessToken {
	t.Helper()

	opts := client.IssueAccessTokenOptions{TTLSeconds: ttlSeconds}

	issued, err := apiClient.VMs().IssueAccessToken(context.Background(), vmName, opts)
	require.NoError(t, err)
	require.NotEmpty(t, issued.Token)

	return issued
}
// parsePort extracts the port from a "host:port" address and returns it as a
// uint16. The test is aborted when the address cannot be split or the port is
// not a base-10 number that fits into 16 bits.
func parsePort(t *testing.T, addr string) uint16 {
	t.Helper()

	_, rawPort, splitErr := net.SplitHostPort(addr)
	require.NoError(t, splitErr)

	parsed, parseErr := strconv.ParseUint(rawPort, 10, 16)
	require.NoError(t, parseErr)

	return uint16(parsed)
}

View File

@ -0,0 +1,306 @@
package vmtempauth
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/google/uuid"
)
const (
SigningKeySizeBytes = 32
JWTHeaderAlgHS256 = "HS256"
JWTHeaderTypJWT = "JWT"
AccessTokenAudience = "orchard-vm-access"
AccessTokenIssuer = "orchard-controller"
ScopeVMPortForward = "vm:port-forward"
ScopeVMIP = "vm:ip"
ScopeVMSSHJumpbox = "vm:ssh-jumpbox"
PortsAny = "*"
DefaultTTL = 24 * time.Hour
MaxTTL = 30 * 24 * time.Hour
SSHUsername = "token"
)
var (
ErrInvalidSigningKey = errors.New("invalid signing key")
ErrMalformedToken = errors.New("malformed token")
ErrInvalidTokenHeader = errors.New("invalid token header")
ErrInvalidTokenClaims = errors.New("invalid token claims")
ErrSignatureMismatch = errors.New("token signature mismatch")
ErrTokenExpired = errors.New("token expired")
ErrTokenNotYetValid = errors.New("token not yet valid")
ErrInvalidTTL = errors.New("invalid access token TTL")
encoding = base64.RawURLEncoding
requiredScopes = []string{ScopeVMPortForward, ScopeVMIP, ScopeVMSSHJumpbox}
)
// header is the JSON structure encoded as the first dot-separated segment of
// a token. Verify only accepts JWTHeaderAlgHS256 and JWTHeaderTypJWT.
type header struct {
// Alg names the signing algorithm.
Alg string `json:"alg"`
// Typ names the token type.
Typ string `json:"typ"`
}
// Claims is the payload of a VM access token. It is serialized as JSON into
// the second dot-separated segment of the token.
type Claims struct {
// Issuer identifies who issued the token; it must be non-empty to validate.
Issuer string `json:"iss"`
// Subject identifies on whose behalf the token was issued; it must be non-empty.
Subject string `json:"sub"`
// Audience must contain AccessTokenAudience for the token to validate.
Audience []string `json:"aud"`
// IssuedAt is the issue time in Unix seconds; zero is rejected.
IssuedAt int64 `json:"iat"`
// NotBefore is the Unix time in seconds before which the token is not yet valid.
NotBefore int64 `json:"nbf"`
// ExpiresAt is the Unix time in seconds at and after which the token is expired.
ExpiresAt int64 `json:"exp"`
// JTI is a per-token identifier; Issue fills it with a freshly generated UUID.
JTI string `json:"jti"`
// VMUID is the UID of the VM this token is bound to (see CanAccessVM).
VMUID string `json:"vm_uid"`
// VMName is the name of the VM at issue time; it is omitted from the JSON when empty.
VMName string `json:"vm_name,omitempty"`
// Scopes lists the granted scopes; validation requires every entry of requiredScopes.
Scopes []string `json:"scopes"`
// Ports is either empty or PortsAny; any other value fails validation.
Ports string `json:"ports,omitempty"`
}
// HasScope reports whether scope is among the scopes granted by the claims.
func (claims Claims) HasScope(scope string) bool {
	position := slices.Index(claims.Scopes, scope)

	return position >= 0
}
// CanAccessVM reports whether the claims are bound to the VM identified by
// vmUID. An empty UID on either side never matches.
func (claims Claims) CanAccessVM(vmUID string) bool {
	if vmUID == "" || claims.VMUID == "" {
		return false
	}

	return vmUID == claims.VMUID
}
// IssueInput carries the parameters for Issue.
type IssueInput struct {
// Issuer becomes the "iss" claim; Issue falls back to AccessTokenIssuer when it is empty.
Issuer string
// Subject becomes the "sub" claim and is required.
Subject string
// VMUID becomes the "vm_uid" claim and is required.
VMUID string
// VMName becomes the "vm_name" claim.
VMName string
// TTL is the token lifetime; zero selects DefaultTTL, while negative values
// and values above MaxTTL are rejected.
TTL time.Duration
// Now is the issue time; Issue falls back to the current UTC time when it is zero.
Now time.Time
}
// IssueOutput is the result of a successful Issue call.
type IssueOutput struct {
// Token is the encoded and signed token string.
Token string
// Claims are the claims embedded in Token.
Claims Claims
// ExpiresAt is the expiration time from Claims, converted to a UTC time.Time.
ExpiresAt time.Time
}
func NewSigningKey() ([]byte, error) {
key := make([]byte, SigningKeySizeBytes)
if _, err := rand.Read(key); err != nil {
return nil, err
}
return key, nil
}
// NormalizeTTL converts an optional TTL expressed in whole seconds into a
// time.Duration.
//
// A nil ttlSeconds selects DefaultTTL. A zero TTL and any TTL above MaxTTL
// are rejected with an error that wraps ErrInvalidTTL.
func NormalizeTTL(ttlSeconds *uint64) (time.Duration, error) {
	if ttlSeconds == nil {
		return DefaultTTL, nil
	}

	seconds := *ttlSeconds
	if seconds == 0 {
		return 0, fmt.Errorf("%w: TTL cannot be zero", ErrInvalidTTL)
	}

	// Range-check in seconds before converting. time.Duration is an int64
	// count of nanoseconds, so multiplying a large uint64 by time.Second
	// wraps around; the wrapped value can be negative or a small positive
	// duration, either of which would slip past a post-conversion check.
	if seconds > uint64(MaxTTL/time.Second) {
		return 0, fmt.Errorf("%w: maximum allowed TTL is %s", ErrInvalidTTL, MaxTTL)
	}

	return time.Duration(seconds) * time.Second, nil
}
// Issue creates and signs a new VM access token.
//
// The signing key must be exactly SigningKeySizeBytes long, and both
// input.Subject and input.VMUID must be set. A zero input.TTL selects
// DefaultTTL; negative TTLs and TTLs above MaxTTL are rejected with an error
// wrapping ErrInvalidTTL. A zero input.Now selects the current UTC time and
// an empty input.Issuer selects AccessTokenIssuer.
//
// The issued token carries every scope in requiredScopes, PortsAny as its
// port restriction and a freshly generated UUID as its identifier.
func Issue(signingKey []byte, input IssueInput) (*IssueOutput, error) {
	switch {
	case len(signingKey) != SigningKeySizeBytes:
		return nil, ErrInvalidSigningKey
	case input.Subject == "":
		return nil, fmt.Errorf("%w: missing subject", ErrInvalidTokenClaims)
	case input.VMUID == "":
		return nil, fmt.Errorf("%w: missing vm_uid", ErrInvalidTokenClaims)
	}

	lifetime := input.TTL
	if lifetime == 0 {
		lifetime = DefaultTTL
	}
	if lifetime < 0 || lifetime > MaxTTL {
		return nil, fmt.Errorf("%w: unsupported TTL %s", ErrInvalidTTL, lifetime)
	}

	issuedAt := input.Now
	if issuedAt.IsZero() {
		issuedAt = time.Now().UTC()
	}

	issuerName := input.Issuer
	if issuerName == "" {
		issuerName = AccessTokenIssuer
	}

	// The token becomes valid immediately, hence "nbf" equals "iat".
	issuedAtUnix := issuedAt.Unix()
	claims := Claims{
		Issuer:    issuerName,
		Subject:   input.Subject,
		Audience:  []string{AccessTokenAudience},
		IssuedAt:  issuedAtUnix,
		NotBefore: issuedAtUnix,
		ExpiresAt: issuedAt.Add(lifetime).Unix(),
		JTI:       uuid.NewString(),
		VMUID:     input.VMUID,
		VMName:    input.VMName,
		// Copy so that the claims never alias the package-level slice.
		Scopes: append([]string{}, requiredScopes...),
		Ports:  PortsAny,
	}

	signed, err := encodeAndSign(signingKey, claims)
	if err != nil {
		return nil, err
	}

	output := &IssueOutput{
		Token:     signed,
		Claims:    claims,
		ExpiresAt: time.Unix(claims.ExpiresAt, 0).UTC(),
	}

	return output, nil
}
// Verify parses token, checks its header and signature against signingKey and
// validates its claims as of now. A zero now selects the current UTC time.
//
// Errors returned:
//   - ErrInvalidSigningKey when signingKey has the wrong length
//   - ErrMalformedToken when the token does not consist of three
//     dot-separated segments or a segment cannot be decoded
//   - ErrInvalidTokenHeader when the algorithm or type is not the expected one
//   - ErrSignatureMismatch when the signature does not match
//   - whatever validateClaims reports for the decoded claims
func Verify(signingKey []byte, token string, now time.Time) (*Claims, error) {
	if len(signingKey) != SigningKeySizeBytes {
		return nil, ErrInvalidSigningKey
	}

	// An empty token splits into a single segment and is rejected below too.
	segments := strings.Split(strings.TrimSpace(token), ".")
	if len(segments) != 3 {
		return nil, ErrMalformedToken
	}
	headerSegment, claimsSegment, signatureSegment := segments[0], segments[1], segments[2]

	headerJSON, err := encoding.DecodeString(headerSegment)
	if err != nil {
		return nil, ErrMalformedToken
	}

	var tokenHeader header
	if json.Unmarshal(headerJSON, &tokenHeader) != nil {
		return nil, ErrMalformedToken
	}
	if tokenHeader.Alg != JWTHeaderAlgHS256 || tokenHeader.Typ != JWTHeaderTypJWT {
		return nil, ErrInvalidTokenHeader
	}

	// The signature covers the first two segments exactly as presented.
	expected := sign(signingKey, fmt.Sprintf("%s.%s", headerSegment, claimsSegment))

	presented, err := encoding.DecodeString(signatureSegment)
	if err != nil {
		return nil, ErrMalformedToken
	}
	if subtle.ConstantTimeCompare(expected, presented) != 1 {
		return nil, ErrSignatureMismatch
	}

	// The claims are only decoded once the signature has been accepted.
	claimsJSON, err := encoding.DecodeString(claimsSegment)
	if err != nil {
		return nil, ErrMalformedToken
	}

	var claims Claims
	if json.Unmarshal(claimsJSON, &claims) != nil {
		return nil, ErrMalformedToken
	}

	if now.IsZero() {
		now = time.Now().UTC()
	}

	if err := validateClaims(claims, now); err != nil {
		return nil, err
	}

	return &claims, nil
}
// encodeAndSign serializes an HS256/JWT header and the given claims to JSON,
// encodes both segments, and returns the compact
// "header.claims.signature" token. The signature is computed by sign over
// the "header.claims" string using signingKey.
//
// An error is returned only if JSON marshaling of the header or claims fails.
func encodeAndSign(signingKey []byte, claims Claims) (string, error) {
	tokenHeader := header{Alg: JWTHeaderAlgHS256, Typ: JWTHeaderTypJWT}

	headerBytes, err := json.Marshal(tokenHeader)
	if err != nil {
		return "", err
	}

	claimsBytes, err := json.Marshal(claims)
	if err != nil {
		return "", err
	}

	unsigned := encoding.EncodeToString(headerBytes) + "." + encoding.EncodeToString(claimsBytes)

	return unsigned + "." + encoding.EncodeToString(sign(signingKey, unsigned)), nil
}
// sign returns the raw HMAC-SHA256 of signedPart keyed with signingKey.
func sign(signingKey []byte, signedPart string) []byte {
	hasher := hmac.New(sha256.New, signingKey)
	// hash.Hash documents that Write never returns an error, so the result
	// is deliberately discarded.
	_, _ = hasher.Write([]byte(signedPart))

	return hasher.Sum(nil)
}
// validateClaims checks decoded claims against the moment now.
//
// The checks run in this order, and the first failing one decides the error:
//
//   - Issuer, Subject, JTI, and VMUID must be non-empty, and IssuedAt,
//     NotBefore, and ExpiresAt must be non-zero (ErrInvalidTokenClaims).
//   - Audience must contain AccessTokenAudience (ErrInvalidTokenClaims).
//   - NotBefore must not be later than now (ErrTokenNotYetValid).
//   - ExpiresAt must be strictly later than now (ErrTokenExpired).
//   - Every entry of requiredScopes must be granted (ErrInvalidTokenClaims).
//   - Ports must be either empty or PortsAny (ErrInvalidTokenClaims).
//
// Time comparisons are made at whole-second (Unix timestamp) resolution.
func validateClaims(claims Claims, now time.Time) error {
	nowUnix := now.Unix()

	switch {
	case claims.Issuer == "", claims.Subject == "", claims.JTI == "", claims.VMUID == "":
		return ErrInvalidTokenClaims
	case claims.IssuedAt == 0, claims.NotBefore == 0, claims.ExpiresAt == 0:
		return ErrInvalidTokenClaims
	case !slices.Contains(claims.Audience, AccessTokenAudience):
		return ErrInvalidTokenClaims
	case claims.NotBefore > nowUnix:
		return ErrTokenNotYetValid
	case claims.ExpiresAt <= nowUnix:
		return ErrTokenExpired
	}

	for _, scope := range requiredScopes {
		if !claims.HasScope(scope) {
			return ErrInvalidTokenClaims
		}
	}

	if claims.Ports == "" || claims.Ports == PortsAny {
		return nil
	}

	return ErrInvalidTokenClaims
}

View File

@ -0,0 +1,89 @@
package vmtempauth_test
import (
	"errors"
	"strings"
	"testing"
	"time"

	"github.com/cirruslabs/orchard/internal/vmtempauth"
	"github.com/stretchr/testify/require"
)
// TestIssueAndVerify issues a token at a fixed instant and verifies it one
// minute later, checking that the subject, VM identity, and all three scopes
// survive the round trip.
func TestIssueAndVerify(t *testing.T) {
	key, err := vmtempauth.NewSigningKey()
	require.NoError(t, err)

	// A fixed instant keeps the test independent of the wall clock.
	issuedAt := time.Unix(1735779600, 0).UTC()
	input := vmtempauth.IssueInput{
		Subject: "issuer",
		VMUID:   "vm-uid",
		VMName:  "vm-name",
		TTL:     10 * time.Minute,
		Now:     issuedAt,
	}

	output, err := vmtempauth.Issue(key, input)
	require.NoError(t, err)

	verified, err := vmtempauth.Verify(key, output.Token, issuedAt.Add(time.Minute))
	require.NoError(t, err)

	require.Equal(t, "issuer", verified.Subject)
	require.Equal(t, "vm-uid", verified.VMUID)
	require.Equal(t, "vm-name", verified.VMName)

	require.True(t, verified.HasScope(vmtempauth.ScopeVMPortForward))
	require.True(t, verified.HasScope(vmtempauth.ScopeVMIP))
	require.True(t, verified.HasScope(vmtempauth.ScopeVMSSHJumpbox))

	require.True(t, verified.CanAccessVM("vm-uid"))
}
// TestVerifyExpired issues a token that lives for one second and verifies it
// two seconds after issuance, expecting ErrTokenExpired.
func TestVerifyExpired(t *testing.T) {
	key, err := vmtempauth.NewSigningKey()
	require.NoError(t, err)

	issuedAt := time.Unix(1735779600, 0).UTC()
	input := vmtempauth.IssueInput{
		Subject: "issuer",
		VMUID:   "vm-uid",
		TTL:     time.Second,
		Now:     issuedAt,
	}

	output, err := vmtempauth.Issue(key, input)
	require.NoError(t, err)

	afterExpiry := issuedAt.Add(2 * time.Second)
	_, err = vmtempauth.Verify(key, output.Token, afterExpiry)
	require.Error(t, err)
	require.ErrorIs(t, err, vmtempauth.ErrTokenExpired)
}
// TestVerifyBadSignature checks that a token whose signature segment has been
// altered is rejected with ErrSignatureMismatch.
//
// The tampering targets the first character of the signature segment. That
// character encodes six significant bits of the first signature byte, so
// changing it always changes the decoded signature. Replacing the last
// character is not reliable: presumably the package uses Go's default
// (non-strict) unpadded base64 decoding (confirm against the encoding
// variable), in which case the last character of a 32-byte signature carries
// only four significant bits and some replacements decode to the very same
// bytes, which makes the test fail intermittently.
func TestVerifyBadSignature(t *testing.T) {
	signingKey, err := vmtempauth.NewSigningKey()
	require.NoError(t, err)

	issued, err := vmtempauth.Issue(signingKey, vmtempauth.IssueInput{
		Subject: "issuer",
		VMUID:   "vm-uid",
		TTL:     time.Minute,
		Now:     time.Now().UTC(),
	})
	require.NoError(t, err)

	// Locate the first character after the final dot, i.e. the start of the
	// signature segment.
	signatureStart := strings.LastIndexByte(issued.Token, '.') + 1
	require.NotZero(t, signatureStart, "token must contain a signature separator")
	require.Less(t, signatureStart, len(issued.Token), "signature segment must not be empty")

	// Pick a replacement that is guaranteed to differ from the original.
	replacement := byte('A')
	if issued.Token[signatureStart] == replacement {
		replacement = 'B'
	}
	tampered := issued.Token[:signatureStart] + string(replacement) + issued.Token[signatureStart+1:]

	_, err = vmtempauth.Verify(signingKey, tampered, time.Now().UTC())
	require.Error(t, err)
	require.ErrorIs(t, err, vmtempauth.ErrSignatureMismatch)
}
// TestNormalizeTTL covers the three NormalizeTTL outcomes exercised here: a
// nil request falls back to DefaultTTL, while a zero TTL and a TTL one second
// above MaxTTL are both rejected with ErrInvalidTTL.
func TestNormalizeTTL(t *testing.T) {
	// No TTL requested: the default applies.
	ttl, err := vmtempauth.NormalizeTTL(nil)
	require.NoError(t, err)
	require.Equal(t, vmtempauth.DefaultTTL, ttl)

	// Zero seconds is not an acceptable lifetime.
	var zeroSeconds uint64
	_, err = vmtempauth.NormalizeTTL(&zeroSeconds)
	require.Error(t, err)
	require.True(t, errors.Is(err, vmtempauth.ErrInvalidTTL))

	// One second past the maximum must be rejected as well.
	overLimitSeconds := uint64(vmtempauth.MaxTTL/time.Second) + 1
	_, err = vmtempauth.NormalizeTTL(&overLimitSeconds)
	require.Error(t, err)
	require.True(t, errors.Is(err, vmtempauth.ErrInvalidTTL))
}

View File

@ -42,6 +42,7 @@ type Client struct {
serviceAccountName string
serviceAccountToken string
bearerToken string
dialer dialer.Dialer
}
@ -321,6 +322,12 @@ func (client *Client) formatPath(path string) *url.URL {
func (client *Client) modifyHeader(header http.Header) {
header.Set("User-Agent", fmt.Sprintf("Orchard/%s", version.FullVersion))
if client.bearerToken != "" {
header.Set("Authorization", fmt.Sprintf("Bearer %s", client.bearerToken))
return
}
if client.serviceAccountName != "" && client.serviceAccountToken != "" {
authPlain := fmt.Sprintf("%s:%s", client.serviceAccountName, client.serviceAccountToken)
authEncoded := base64.StdEncoding.EncodeToString([]byte(authPlain))

View File

@ -30,6 +30,12 @@ func WithCredentials(serviceAccountName string, serviceAccountToken string) Opti
}
}
// WithBearerToken returns an Option that stores token as the client's bearer
// token.
func WithBearerToken(token string) Option {
	setBearerToken := func(client *Client) {
		client.bearerToken = token
	}

	return setBearerToken
}
func WithDialer(dialer dialer.Dialer) Option {
return func(client *Client) {
client.dialer = dialer

View File

@ -46,6 +46,10 @@ type EventsPageOptions struct {
Cursor string
}
// IssueAccessTokenOptions carries the optional parameters accepted by
// VMsService.IssueAccessToken.
type IssueAccessTokenOptions struct {
// TTLSeconds is the requested token lifetime in seconds. It is copied
// verbatim into the request body; a nil value is omitted from the JSON,
// leaving the choice of lifetime to the server.
TTLSeconds *uint64
}
func (service *VMsService) Create(ctx context.Context, vm *v1.VM) error {
err := service.client.request(ctx, http.MethodPost, "vms",
vm, nil, nil)
@ -172,6 +176,26 @@ func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint
return result.IP, nil
}
// IssueAccessToken requests a temporary access token for the VM called name
// by POSTing to "vms/{name}/access-tokens". The VM name is path-escaped
// before it is placed into the URL.
//
// options.TTLSeconds is forwarded as the requested lifetime. On success the
// decoded response is returned; on failure the error from the underlying
// request is returned unchanged together with a nil token.
func (service *VMsService) IssueAccessToken(
	ctx context.Context,
	name string,
	options IssueAccessTokenOptions,
) (*v1.VMAccessToken, error) {
	endpoint := fmt.Sprintf("vms/%s/access-tokens", url.PathEscape(name))
	body := v1.IssueVMAccessTokenRequest{TTLSeconds: options.TTLSeconds}

	issued := new(v1.VMAccessToken)
	if err := service.client.request(ctx, http.MethodPost, endpoint, body, issued, nil); err != nil {
		return nil, err
	}

	return issued, nil
}
// StreamEvents returns an EventStreamer bound to the "vms/{name}/events"
// path of the VM called name. The name is path-escaped first.
func (service *VMsService) StreamEvents(name string) *EventStreamer {
	eventsPath := fmt.Sprintf("vms/%s/events", url.PathEscape(name))

	return NewEventStreamer(service.client, eventsPath)
}

View File

@ -0,0 +1,16 @@
package v1
import "time"
// IssueVMAccessTokenRequest is the JSON request body for issuing a temporary
// VM access token.
type IssueVMAccessTokenRequest struct {
// TTLSeconds is the requested token lifetime in seconds. A nil value is
// omitted from the JSON. Per the API specification, the lifetime then
// defaults to 86400 (24h), and the maximum allowed value is 2592000 (30d).
TTLSeconds *uint64 `json:"ttlSeconds,omitempty"`
}
// VMAccessToken is the JSON response describing an issued temporary VM
// access token.
type VMAccessToken struct {
// Token is the signed temporary JWT.
Token string `json:"token,omitempty"`
// TokenType is the token type to use in the Authorization header,
// e.g. "Bearer".
TokenType string `json:"tokenType,omitempty"`
// ExpiresAt is the instant at which the token expires, serialized as an
// RFC 3339 timestamp.
// NOTE(review): encoding/json's omitempty never omits struct values such
// as time.Time, so a zero ExpiresAt is still written to the JSON.
ExpiresAt time.Time `json:"expiresAt,omitempty"`
// SSHUsername is the username to use when authenticating to the built-in
// SSH server, e.g. "token".
SSHUsername string `json:"sshUsername,omitempty"`
// VMName is the name of the VM this token was issued for.
VMName string `json:"vmName,omitempty"`
// VMUID is the immutable UID of the VM this token is bound to.
VMUID string `json:"vmUID,omitempty"`
}