Add VM-scoped temporary JWT access tokens
This commit is contained in:
parent
448da2a83b
commit
310ff200ea
|
|
@ -447,6 +447,32 @@ paths:
|
|||
description: VM resource with the given name doesn't exist
|
||||
'503':
|
||||
description: Failed to resolve the IP address on the worker responsible for the specified VM
|
||||
/vms/{name}/access-tokens:
|
||||
parameters:
|
||||
- in: path
|
||||
name: name
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
post:
|
||||
summary: "Issue a temporary VM access token"
|
||||
tags:
|
||||
- vms
|
||||
requestBody:
|
||||
required: false
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/IssueVMAccessTokenRequest'
|
||||
responses:
|
||||
'200':
|
||||
description: VM access token was successfully issued
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/VMAccessToken'
|
||||
'404':
|
||||
description: VM resource with the given name doesn't exist
|
||||
components:
|
||||
schemas:
|
||||
Worker:
|
||||
|
|
@ -693,6 +719,41 @@ components:
|
|||
ip:
|
||||
type: string
|
||||
description: The resolved IP address
|
||||
IssueVMAccessTokenRequest:
|
||||
title: Issue VM Access Token Request
|
||||
type: object
|
||||
properties:
|
||||
ttlSeconds:
|
||||
type: integer
|
||||
format: uint64
|
||||
description: |
|
||||
Requested token lifetime in seconds. Defaults to 86400 (24h) when omitted.
|
||||
Maximum allowed value is 2592000 (30d).
|
||||
VMAccessToken:
|
||||
title: VM Access Token
|
||||
type: object
|
||||
properties:
|
||||
token:
|
||||
type: string
|
||||
description: Signed temporary JWT token
|
||||
tokenType:
|
||||
type: string
|
||||
description: Token type to use in Authorization header
|
||||
example: Bearer
|
||||
expiresAt:
|
||||
type: string
|
||||
format: date-time
|
||||
description: RFC3339 timestamp at which the token expires
|
||||
sshUsername:
|
||||
type: string
|
||||
description: Username to use when authenticating to the built-in SSH server
|
||||
example: token
|
||||
vmName:
|
||||
type: string
|
||||
description: VM name this token was issued for
|
||||
vmUID:
|
||||
type: string
|
||||
description: Immutable VM UID this token is bound to
|
||||
Event:
|
||||
title: Generic Resource Event
|
||||
type: object
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ func NewCommand() *cobra.Command {
|
|||
newGetBootstrapTokenCommand(),
|
||||
newGetClusterSettingsCommand(),
|
||||
newGetServiceAccountCommand(),
|
||||
newGetVMAccessTokenCommand(),
|
||||
newGetVMCommand(),
|
||||
newGetWorkerCommand(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
package get
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
"github.com/cirruslabs/orchard/pkg/client"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var vmAccessTokenTTL time.Duration
|
||||
|
||||
func newGetVMAccessTokenCommand() *cobra.Command {
|
||||
command := &cobra.Command{
|
||||
Use: "vm-access-token VM_NAME",
|
||||
Short: "Issue a temporary VM access token",
|
||||
RunE: runGetVMAccessToken,
|
||||
Args: cobra.ExactArgs(1),
|
||||
}
|
||||
|
||||
command.Flags().DurationVar(&vmAccessTokenTTL, "ttl", vmtempauth.DefaultTTL,
|
||||
fmt.Sprintf("token TTL (default: %s, max: %s)", vmtempauth.DefaultTTL, vmtempauth.MaxTTL))
|
||||
|
||||
return command
|
||||
}
|
||||
|
||||
func runGetVMAccessToken(cmd *cobra.Command, args []string) error {
|
||||
name := args[0]
|
||||
|
||||
if vmAccessTokenTTL <= 0 {
|
||||
return fmt.Errorf("%w: --ttl must be greater than 0", ErrGetFailed)
|
||||
}
|
||||
if vmAccessTokenTTL > vmtempauth.MaxTTL {
|
||||
return fmt.Errorf("%w: --ttl cannot exceed %s", ErrGetFailed, vmtempauth.MaxTTL)
|
||||
}
|
||||
|
||||
ttlSeconds := uint64(vmAccessTokenTTL / time.Second)
|
||||
if ttlSeconds == 0 {
|
||||
return fmt.Errorf("%w: --ttl is too small", ErrGetFailed)
|
||||
}
|
||||
|
||||
apiClient, err := client.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response, err := apiClient.VMs().IssueAccessToken(cmd.Context(), name, client.IssueAccessTokenOptions{
|
||||
TTLSeconds: &ttlSeconds,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println(response.Token)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -7,10 +7,12 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/api"
|
||||
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
||||
"github.com/cirruslabs/orchard/internal/responder"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
v1pkg "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"github.com/deckarep/golang-set/v2"
|
||||
|
|
@ -24,6 +26,7 @@ import (
|
|||
)
|
||||
|
||||
const ctxServiceAccountKey = "service-account"
|
||||
const ctxVMAccessTokenClaimsKey = "vm-access-token-claims"
|
||||
|
||||
var ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
|
|
@ -183,6 +186,9 @@ func (controller *Controller) initAPI() *gin.Engine {
|
|||
v1.POST("/vms/:name/events", func(c *gin.Context) {
|
||||
controller.appendVMEvents(c).Respond(c)
|
||||
})
|
||||
v1.POST("/vms/:name/access-tokens", func(c *gin.Context) {
|
||||
controller.issueVMAccessToken(c).Respond(c)
|
||||
})
|
||||
|
||||
return ginEngine
|
||||
}
|
||||
|
|
@ -215,6 +221,29 @@ func (controller *Controller) fetchServiceAccount(name string, token string) (*v
|
|||
}
|
||||
|
||||
func (controller *Controller) authenticateMiddleware(c *gin.Context) {
|
||||
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
|
||||
if len(authHeader) >= len("Bearer ") && strings.EqualFold(authHeader[:len("Bearer ")], "Bearer ") {
|
||||
token := strings.TrimSpace(authHeader[len("Bearer "):])
|
||||
if token == "" {
|
||||
responder.Code(http.StatusUnauthorized).Respond(c)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := vmtempauth.Verify(controller.vmAccessTokenKey, token, time.Now().UTC())
|
||||
if err != nil {
|
||||
responder.Code(http.StatusUnauthorized).Respond(c)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(ctxVMAccessTokenClaimsKey, claims)
|
||||
c.Next()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve presented credentials (if any)
|
||||
user, password, ok := c.Request.BasicAuth()
|
||||
if !ok {
|
||||
|
|
@ -240,6 +269,34 @@ func (controller *Controller) authenticateMiddleware(c *gin.Context) {
|
|||
c.Next()
|
||||
}
|
||||
|
||||
func (controller *Controller) serviceAccountFromContext(ctx *gin.Context) (*v1pkg.ServiceAccount, bool) {
|
||||
untypeServiceAccount, ok := ctx.Get(ctxServiceAccountKey)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
serviceAccount, ok := untypeServiceAccount.(*v1pkg.ServiceAccount)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return serviceAccount, true
|
||||
}
|
||||
|
||||
func (controller *Controller) vmAccessTokenClaimsFromContext(ctx *gin.Context) (*vmtempauth.Claims, bool) {
|
||||
untypeClaims, ok := ctx.Get(ctxVMAccessTokenClaimsKey)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
claims, ok := untypeClaims.(*vmtempauth.Claims)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return claims, true
|
||||
}
|
||||
|
||||
type AuthorizeMode int
|
||||
|
||||
const (
|
||||
|
|
@ -270,11 +327,10 @@ func (controller *Controller) authorizeBase(
|
|||
return nil
|
||||
}
|
||||
|
||||
serviceAccountUntyped, ok := ctx.Get(ctxServiceAccountKey)
|
||||
serviceAccount, ok := controller.serviceAccountFromContext(ctx)
|
||||
if !ok {
|
||||
return responder.Code(http.StatusUnauthorized)
|
||||
}
|
||||
serviceAccount := serviceAccountUntyped.(*v1pkg.ServiceAccount)
|
||||
serviceAccountRolesSet := mapset.NewSet[v1pkg.ServiceAccountRole](serviceAccount.Roles...)
|
||||
|
||||
var authorized bool
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
||||
"github.com/cirruslabs/orchard/internal/responder"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func (controller *Controller) issueVMAccessToken(ctx *gin.Context) responder.Responder {
|
||||
if responder := controller.authorize(ctx, v1.ServiceAccountRoleComputeWrite); responder != nil {
|
||||
return responder
|
||||
}
|
||||
|
||||
var request v1.IssueVMAccessTokenRequest
|
||||
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil && !errors.Is(err, io.EOF) {
|
||||
return responder.JSON(http.StatusBadRequest, NewErrorResponse("invalid JSON was provided"))
|
||||
}
|
||||
|
||||
ttl, err := vmtempauth.NormalizeTTL(request.TTLSeconds)
|
||||
if err != nil {
|
||||
return responder.JSON(http.StatusPreconditionFailed, NewErrorResponse("%v", err))
|
||||
}
|
||||
|
||||
name := ctx.Param("name")
|
||||
var vm *v1.VM
|
||||
|
||||
if responder := controller.storeView(func(txn storepkg.Transaction) responder.Responder {
|
||||
vm, err = txn.GetVM(name)
|
||||
if err != nil {
|
||||
return responder.Error(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); responder != nil {
|
||||
return responder
|
||||
}
|
||||
|
||||
serviceAccount, ok := controller.serviceAccountFromContext(ctx)
|
||||
if !ok {
|
||||
return responder.Code(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
token, err := vmtempauth.Issue(controller.vmAccessTokenKey, vmtempauth.IssueInput{
|
||||
Issuer: vmtempauth.AccessTokenIssuer,
|
||||
Subject: serviceAccount.Name,
|
||||
VMUID: vm.UID,
|
||||
VMName: vm.Name,
|
||||
TTL: ttl,
|
||||
Now: time.Now().UTC(),
|
||||
})
|
||||
if err != nil {
|
||||
controller.logger.Errorf("failed to issue VM access token: %v", err)
|
||||
|
||||
return responder.Code(http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return responder.JSON(http.StatusOK, &v1.VMAccessToken{
|
||||
Token: token.Token,
|
||||
TokenType: "Bearer",
|
||||
ExpiresAt: token.ExpiresAt,
|
||||
SSHUsername: vmtempauth.SSHUsername,
|
||||
VMName: vm.Name,
|
||||
VMUID: vm.UID,
|
||||
})
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/responder"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"github.com/gin-gonic/gin"
|
||||
|
|
@ -15,9 +16,15 @@ import (
|
|||
)
|
||||
|
||||
func (controller *Controller) ip(ctx *gin.Context) responder.Responder {
|
||||
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
||||
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
||||
return responder
|
||||
vmAccessTokenClaims, vmAccessTokenAuth := controller.vmAccessTokenClaimsFromContext(ctx)
|
||||
|
||||
if !vmAccessTokenAuth {
|
||||
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
||||
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
||||
return responder
|
||||
}
|
||||
} else if !vmAccessTokenClaims.HasScope(vmtempauth.ScopeVMIP) {
|
||||
return responder.Code(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Retrieve and parse path and query parameters
|
||||
|
|
@ -36,6 +43,9 @@ func (controller *Controller) ip(ctx *gin.Context) responder.Responder {
|
|||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
if vmAccessTokenAuth && !vmAccessTokenClaims.CanAccessVM(vm.UID) {
|
||||
return responder.JSON(http.StatusForbidden, NewErrorResponse("the VM access token does not allow access to this VM"))
|
||||
}
|
||||
|
||||
// Send an IP resolution request and wait for the result
|
||||
session := uuid.New().String()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"github.com/cirruslabs/orchard/internal/netconncancel"
|
||||
"github.com/cirruslabs/orchard/internal/proxy"
|
||||
"github.com/cirruslabs/orchard/internal/responder"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"github.com/coder/websocket"
|
||||
|
|
@ -22,9 +23,15 @@ import (
|
|||
)
|
||||
|
||||
func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responder {
|
||||
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
||||
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
||||
return responder
|
||||
vmAccessTokenClaims, vmAccessTokenAuth := controller.vmAccessTokenClaimsFromContext(ctx)
|
||||
|
||||
if !vmAccessTokenAuth {
|
||||
if responder := controller.authorizeAny(ctx, v1.ServiceAccountRoleComputeWrite,
|
||||
v1.ServiceAccountRoleComputeConnect); responder != nil {
|
||||
return responder
|
||||
}
|
||||
} else if !vmAccessTokenClaims.HasScope(vmtempauth.ScopeVMPortForward) {
|
||||
return responder.Code(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// Retrieve and parse path and query parameters
|
||||
|
|
@ -52,6 +59,9 @@ func (controller *Controller) portForwardVM(ctx *gin.Context) responder.Responde
|
|||
if responderImpl != nil {
|
||||
return responderImpl
|
||||
}
|
||||
if vmAccessTokenAuth && !vmAccessTokenClaims.CanAccessVM(vm.UID) {
|
||||
return responder.JSON(http.StatusForbidden, NewErrorResponse("the VM access token does not allow access to this VM"))
|
||||
}
|
||||
|
||||
// Commence port forwarding
|
||||
return controller.portForward(ctx, waitContext, vm.Worker, vm.UID, uint32(port))
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import (
|
|||
"github.com/cirruslabs/orchard/internal/controller/store/badger"
|
||||
"github.com/cirruslabs/orchard/internal/netconstants"
|
||||
"github.com/cirruslabs/orchard/internal/opentelemetry"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"github.com/samber/lo"
|
||||
|
|
@ -67,6 +68,7 @@ type Controller struct {
|
|||
pingInterval time.Duration
|
||||
prometheusMetrics bool
|
||||
synthetic bool
|
||||
vmAccessTokenKey []byte
|
||||
|
||||
sshListenAddr string
|
||||
sshSigner ssh.Signer
|
||||
|
|
@ -115,6 +117,10 @@ func New(opts ...Option) (*Controller, error) {
|
|||
if controller.logger == nil {
|
||||
controller.logger = zap.NewNop().Sugar()
|
||||
}
|
||||
var err error
|
||||
if controller.vmAccessTokenKey, err = controller.loadOrCreateVMAccessTokenKey(); err != nil {
|
||||
return nil, fmt.Errorf("%w: failed to initialize VM access token key: %v", ErrInitFailed, err)
|
||||
}
|
||||
|
||||
// Instantiate the database
|
||||
store, err := badger.NewBadgerStore(controller.dataDir.DBPath(), controller.disableDBCompression,
|
||||
|
|
@ -137,7 +143,8 @@ func New(opts ...Option) (*Controller, error) {
|
|||
// Instantiate the SSH server (if configured)
|
||||
if controller.sshListenAddr != "" && controller.sshSigner != nil {
|
||||
controller.sshServer, err = sshserver.NewSSHServer(controller.sshListenAddr, controller.sshSigner,
|
||||
store, controller.connRendezvous, controller.workerNotifier, controller.sshNoClientAuth, controller.logger)
|
||||
store, controller.connRendezvous, controller.workerNotifier, controller.vmAccessTokenKey,
|
||||
controller.sshNoClientAuth, controller.logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -197,6 +204,32 @@ func New(opts ...Option) (*Controller, error) {
|
|||
return controller, nil
|
||||
}
|
||||
|
||||
func (controller *Controller) loadOrCreateVMAccessTokenKey() ([]byte, error) {
|
||||
signingKey, err := controller.dataDir.VMAccessTokenSigningKey()
|
||||
if err == nil {
|
||||
if len(signingKey) != vmtempauth.SigningKeySizeBytes {
|
||||
return nil, fmt.Errorf("%w: expected %d bytes, got %d",
|
||||
vmtempauth.ErrInvalidSigningKey, vmtempauth.SigningKeySizeBytes, len(signingKey))
|
||||
}
|
||||
|
||||
return signingKey, nil
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
signingKey, err = vmtempauth.NewSigningKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := controller.dataDir.SetVMAccessTokenSigningKey(signingKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return signingKey, nil
|
||||
}
|
||||
|
||||
func (controller *Controller) ServiceAccounts() ([]v1.ServiceAccount, error) {
|
||||
var serviceAccounts []v1.ServiceAccount
|
||||
var err error
|
||||
|
|
|
|||
|
|
@ -99,6 +99,18 @@ func (dataDir *DataDir) SSHHostKeyPath() string {
|
|||
return filepath.Join(dataDir.path, "ssh_host_ed25519_key")
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) VMAccessTokenSigningKeyPath() string {
|
||||
return filepath.Join(dataDir.path, "vm_access_token_signing_key")
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) VMAccessTokenSigningKey() ([]byte, error) {
|
||||
return os.ReadFile(dataDir.VMAccessTokenSigningKeyPath())
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) SetVMAccessTokenSigningKey(signingKey []byte) error {
|
||||
return os.WriteFile(dataDir.VMAccessTokenSigningKeyPath(), signingKey, 0600)
|
||||
}
|
||||
|
||||
func (dataDir *DataDir) Initialized() (bool, error) {
|
||||
dataDirEntries, err := os.ReadDir(dataDir.path)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
|
||||
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
||||
"github.com/cirruslabs/orchard/internal/proxy"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
"github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/cirruslabs/orchard/rpc"
|
||||
"github.com/google/uuid"
|
||||
|
|
@ -27,15 +28,21 @@ const (
|
|||
//
|
||||
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
|
||||
channelTypeDirectTCPIP = "direct-tcpip"
|
||||
|
||||
permissionPrincipalTypeExt = "orchard-principal-type"
|
||||
permissionVMUIDExt = "orchard-vm-uid"
|
||||
|
||||
principalTypeVMAccessToken = "vm-access-token"
|
||||
)
|
||||
|
||||
type SSHServer struct {
|
||||
listener net.Listener
|
||||
serverConfig *ssh.ServerConfig
|
||||
store storepkg.Store
|
||||
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]]
|
||||
workerNotifier *notifier.Notifier
|
||||
logger *zap.SugaredLogger
|
||||
listener net.Listener
|
||||
serverConfig *ssh.ServerConfig
|
||||
store storepkg.Store
|
||||
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]]
|
||||
workerNotifier *notifier.Notifier
|
||||
vmAccessTokenKey []byte
|
||||
logger *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewSSHServer(
|
||||
|
|
@ -44,14 +51,16 @@ func NewSSHServer(
|
|||
store storepkg.Store,
|
||||
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]],
|
||||
workerNotifier *notifier.Notifier,
|
||||
vmAccessTokenKey []byte,
|
||||
noClientAuth bool,
|
||||
logger *zap.SugaredLogger,
|
||||
) (*SSHServer, error) {
|
||||
server := &SSHServer{
|
||||
store: store,
|
||||
connRendezvous: connRendezvous,
|
||||
workerNotifier: workerNotifier,
|
||||
logger: logger,
|
||||
store: store,
|
||||
connRendezvous: connRendezvous,
|
||||
workerNotifier: workerNotifier,
|
||||
vmAccessTokenKey: vmAccessTokenKey,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", address)
|
||||
|
|
@ -87,6 +96,24 @@ func (server *SSHServer) Address() string {
|
|||
}
|
||||
|
||||
func (server *SSHServer) passwordCallback(connMetadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if connMetadata.User() == vmtempauth.SSHUsername {
|
||||
claims, err := vmtempauth.Verify(server.vmAccessTokenKey, string(password), time.Now().UTC())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authentication failed for user %q: invalid VM access token", connMetadata.User())
|
||||
}
|
||||
if !claims.HasScope(vmtempauth.ScopeVMSSHJumpbox) {
|
||||
return nil, fmt.Errorf("authorization failed for user %q: token lacks %q scope",
|
||||
connMetadata.User(), vmtempauth.ScopeVMSSHJumpbox)
|
||||
}
|
||||
|
||||
return &ssh.Permissions{
|
||||
Extensions: map[string]string{
|
||||
permissionPrincipalTypeExt: principalTypeVMAccessToken,
|
||||
permissionVMUIDExt: claims.VMUID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if err := server.store.View(func(txn storepkg.Transaction) error {
|
||||
// Authenticate
|
||||
server.logger.Debugf("authenticating user %q using the password authentication",
|
||||
|
|
@ -157,7 +184,7 @@ func (server *SSHServer) handleConnection(conn net.Conn) {
|
|||
server.logger.Debugf("handling a new direct TCP/IP channel for user %q connecting from %q",
|
||||
sshConn.User(), sshConn.RemoteAddr().String())
|
||||
|
||||
go server.handleDirectTCPIP(connCtx, newChannel)
|
||||
go server.handleDirectTCPIP(connCtx, sshConn, newChannel)
|
||||
default:
|
||||
message := fmt.Sprintf("unsupported channel type requested: %q", newChannel.ChannelType())
|
||||
|
||||
|
|
@ -188,7 +215,7 @@ func (server *SSHServer) handleConnection(conn net.Conn) {
|
|||
}
|
||||
}
|
||||
|
||||
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.NewChannel) {
|
||||
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, sshConn *ssh.ServerConn, newChannel ssh.NewChannel) {
|
||||
// Unmarshal the payload to determine to which VM the user wants to connect to
|
||||
//
|
||||
// This direct TCP/IP channel's payload is documented
|
||||
|
|
@ -234,6 +261,20 @@ func (server *SSHServer) handleDirectTCPIP(ctx context.Context, newChannel ssh.N
|
|||
return
|
||||
}
|
||||
|
||||
if sshConn.Permissions != nil &&
|
||||
sshConn.Permissions.Extensions[permissionPrincipalTypeExt] == principalTypeVMAccessToken {
|
||||
tokenVMUID := sshConn.Permissions.Extensions[permissionVMUIDExt]
|
||||
if tokenVMUID == "" || tokenVMUID != vm.UID {
|
||||
message := "authorization failed for requested VM"
|
||||
|
||||
if err := newChannel.Reject(ssh.Prohibited, message); err != nil {
|
||||
server.logger.Warnf("failed to reject the new channel due to VM access token mismatch: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// The user wants to connect to an existing VM, request and wait
|
||||
// for a connection with the worker before accepting the channel
|
||||
session := uuid.New().String()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,337 @@
|
|||
package tests_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/controller"
|
||||
"github.com/cirruslabs/orchard/internal/echoserver"
|
||||
"github.com/cirruslabs/orchard/internal/imageconstant"
|
||||
"github.com/cirruslabs/orchard/internal/tests/wait"
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
workerpkg "github.com/cirruslabs/orchard/internal/worker"
|
||||
"github.com/cirruslabs/orchard/pkg/client"
|
||||
v1 "github.com/cirruslabs/orchard/pkg/resource/v1"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ed25519"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type authenticatedTestEnvironment struct {
|
||||
adminClient *client.Client
|
||||
controller *controller.Controller
|
||||
worker *workerpkg.Worker
|
||||
}
|
||||
|
||||
func TestVMAccessTokenAPI(t *testing.T) {
|
||||
env := startAuthenticatedTestEnvironment(t, false)
|
||||
defer func() {
|
||||
_ = env.worker.Close()
|
||||
}()
|
||||
|
||||
vmA := createRunningSyntheticVM(t, env.adminClient, "test-vm-a")
|
||||
vmB := createRunningSyntheticVM(t, env.adminClient, "test-vm-b")
|
||||
|
||||
echoServer, err := echoserver.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
echoContext, echoCancel := context.WithCancel(context.Background())
|
||||
defer echoCancel()
|
||||
|
||||
go func() {
|
||||
_ = echoServer.Run(echoContext)
|
||||
}()
|
||||
|
||||
echoPort := parsePort(t, echoServer.Addr())
|
||||
|
||||
tokenA := issueAccessToken(t, env.adminClient, vmA.Name, nil)
|
||||
|
||||
tokenClient, err := client.New(client.WithAddress(env.controller.Address()), client.WithBearerToken(tokenA.Token))
|
||||
require.NoError(t, err)
|
||||
|
||||
ip, err := tokenClient.VMs().IP(context.Background(), vmA.Name, 30)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, ip)
|
||||
|
||||
vmAConn, err := tokenClient.VMs().PortForward(context.Background(), vmA.Name, echoPort, 30)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = vmAConn.Close()
|
||||
})
|
||||
|
||||
require.NoError(t, vmAConn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
_, err = vmAConn.Write([]byte("hello"))
|
||||
require.NoError(t, err)
|
||||
|
||||
result := make([]byte, len("hello"))
|
||||
_, err = io.ReadFull(vmAConn, result)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", string(result))
|
||||
|
||||
_, err = tokenClient.VMs().IP(context.Background(), vmB.Name, 30)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, client.ErrAPI)
|
||||
require.Contains(t, err.Error(), "403")
|
||||
|
||||
_, err = tokenClient.VMs().List(context.Background())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, client.ErrAPI)
|
||||
require.Contains(t, err.Error(), "401")
|
||||
}
|
||||
|
||||
func TestVMAccessTokenVMRecreation(t *testing.T) {
|
||||
env := startAuthenticatedTestEnvironment(t, false)
|
||||
defer func() {
|
||||
_ = env.worker.Close()
|
||||
}()
|
||||
|
||||
vm := createRunningSyntheticVM(t, env.adminClient, "test-vm")
|
||||
|
||||
token := issueAccessToken(t, env.adminClient, vm.Name, nil)
|
||||
|
||||
tokenClient, err := client.New(client.WithAddress(env.controller.Address()), client.WithBearerToken(token.Token))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, env.adminClient.VMs().Delete(context.Background(), vm.Name))
|
||||
|
||||
createRunningSyntheticVM(t, env.adminClient, vm.Name)
|
||||
|
||||
_, err = tokenClient.VMs().IP(context.Background(), vm.Name, 30)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, client.ErrAPI)
|
||||
require.Contains(t, err.Error(), "403")
|
||||
}
|
||||
|
||||
func TestVMAccessTokenSSHServer(t *testing.T) {
|
||||
env := startAuthenticatedTestEnvironment(t, true)
|
||||
defer func() {
|
||||
_ = env.worker.Close()
|
||||
}()
|
||||
|
||||
vmA := createRunningSyntheticVM(t, env.adminClient, "test-vm-a")
|
||||
vmB := createRunningSyntheticVM(t, env.adminClient, "test-vm-b")
|
||||
|
||||
echoServer, err := echoserver.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
echoContext, echoCancel := context.WithCancel(context.Background())
|
||||
defer echoCancel()
|
||||
|
||||
go func() {
|
||||
_ = echoServer.Run(echoContext)
|
||||
}()
|
||||
|
||||
echoPort := parsePort(t, echoServer.Addr())
|
||||
|
||||
tokenA := issueAccessToken(t, env.adminClient, vmA.Name, nil)
|
||||
|
||||
sshAddress, ok := env.controller.SSHAddress()
|
||||
require.True(t, ok)
|
||||
|
||||
sshClientController, err := ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{
|
||||
User: vmtempauth.SSHUsername,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(tokenA.Token),
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sshClientController.Close()
|
||||
})
|
||||
|
||||
vmAConn, err := sshClientController.Dial("tcp", fmt.Sprintf("%s:%d", vmA.Name, echoPort))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = vmAConn.Close()
|
||||
})
|
||||
|
||||
_, err = vmAConn.Write([]byte("jumpbox"))
|
||||
require.NoError(t, err)
|
||||
|
||||
vmAResult := make([]byte, len("jumpbox"))
|
||||
_, err = io.ReadFull(vmAConn, vmAResult)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "jumpbox", string(vmAResult))
|
||||
|
||||
_, err = sshClientController.Dial("tcp", fmt.Sprintf("%s:%d", vmB.Name, echoPort))
|
||||
require.Error(t, err)
|
||||
|
||||
shortTTL := uint64(1)
|
||||
shortToken := issueAccessToken(t, env.adminClient, vmA.Name, &shortTTL)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
_, err = ssh.Dial("tcp", sshAddress, &ssh.ClientConfig{
|
||||
User: vmtempauth.SSHUsername,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(shortToken.Token),
|
||||
},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func startAuthenticatedTestEnvironment(
|
||||
t *testing.T,
|
||||
withSSHServer bool,
|
||||
) *authenticatedTestEnvironment {
|
||||
t.Helper()
|
||||
|
||||
dataDir, err := controller.NewDataDir(t.TempDir())
|
||||
require.NoError(t, err)
|
||||
|
||||
controllerOpts := []controller.Option{
|
||||
controller.WithDataDir(dataDir),
|
||||
controller.WithListenAddr(":0"),
|
||||
controller.WithExperimentalRPCV2(),
|
||||
}
|
||||
|
||||
if withSSHServer {
|
||||
_, privateKey, err := ed25519.GenerateKey(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
controllerOpts = append(controllerOpts, controller.WithSSHServer(":0", signer, false))
|
||||
}
|
||||
|
||||
controllerInstance, err := controller.New(controllerOpts...)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, controllerInstance.EnsureServiceAccount(&v1.ServiceAccount{
|
||||
Meta: v1.Meta{
|
||||
Name: "admin",
|
||||
},
|
||||
Token: "admin-token",
|
||||
Roles: v1.AllServiceAccountRoles(),
|
||||
}))
|
||||
require.NoError(t, controllerInstance.EnsureServiceAccount(&v1.ServiceAccount{
|
||||
Meta: v1.Meta{
|
||||
Name: "worker",
|
||||
},
|
||||
Token: "worker-token",
|
||||
Roles: []v1.ServiceAccountRole{
|
||||
v1.ServiceAccountRoleComputeRead,
|
||||
v1.ServiceAccountRoleComputeWrite,
|
||||
},
|
||||
}))
|
||||
|
||||
adminClient, err := client.New(
|
||||
client.WithAddress(controllerInstance.Address()),
|
||||
client.WithCredentials("admin", "admin-token"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
workerClient, err := client.New(
|
||||
client.WithAddress(controllerInstance.Address()),
|
||||
client.WithCredentials("worker", "worker-token"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
workerInstance, err := workerpkg.New(workerClient, workerpkg.WithName("worker-a"), workerpkg.WithSynthetic())
|
||||
require.NoError(t, err)
|
||||
|
||||
testContext, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
go func() {
|
||||
runErr := controllerInstance.Run(testContext)
|
||||
if runErr != nil && !errors.Is(runErr, context.Canceled) && !errors.Is(runErr, http.ErrServerClosed) {
|
||||
t.Errorf("controller failed: %v", runErr)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
runErr := workerInstance.Run(testContext)
|
||||
if runErr != nil && !errors.Is(runErr, context.Canceled) {
|
||||
t.Errorf("worker failed: %v", runErr)
|
||||
}
|
||||
}()
|
||||
|
||||
assert.True(t, wait.Wait(30*time.Second, func() bool {
|
||||
workers, err := adminClient.Workers().List(context.Background())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return len(workers) == 1
|
||||
}), "failed to wait for worker to register")
|
||||
|
||||
return &authenticatedTestEnvironment{
|
||||
adminClient: adminClient,
|
||||
controller: controllerInstance,
|
||||
worker: workerInstance,
|
||||
}
|
||||
}
|
||||
|
||||
func createRunningSyntheticVM(t *testing.T, apiClient *client.Client, name string) *v1.VM {
|
||||
t.Helper()
|
||||
|
||||
err := apiClient.VMs().Create(context.Background(), &v1.VM{
|
||||
Meta: v1.Meta{
|
||||
Name: name,
|
||||
},
|
||||
Image: imageconstant.DefaultMacosImage,
|
||||
CPU: 1,
|
||||
Memory: 512,
|
||||
Headless: true,
|
||||
Status: v1.VMStatusPending,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Truef(t, wait.Wait(60*time.Second, func() bool {
|
||||
vm, err := apiClient.VMs().Get(context.Background(), name)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return vm.Status == v1.VMStatusRunning
|
||||
}), "failed to wait for VM %q to reach running state", name)
|
||||
|
||||
vm, err := apiClient.VMs().Get(context.Background(), name)
|
||||
require.NoError(t, err)
|
||||
|
||||
return vm
|
||||
}
|
||||
|
||||
func issueAccessToken(
|
||||
t *testing.T,
|
||||
apiClient *client.Client,
|
||||
vmName string,
|
||||
ttlSeconds *uint64,
|
||||
) *v1.VMAccessToken {
|
||||
t.Helper()
|
||||
|
||||
token, err := apiClient.VMs().IssueAccessToken(context.Background(), vmName, client.IssueAccessTokenOptions{
|
||||
TTLSeconds: ttlSeconds,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token.Token)
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
func parsePort(t *testing.T, addr string) uint16 {
|
||||
t.Helper()
|
||||
|
||||
_, portRaw, err := net.SplitHostPort(addr)
|
||||
require.NoError(t, err)
|
||||
|
||||
port, err := strconv.ParseUint(portRaw, 10, 16)
|
||||
require.NoError(t, err)
|
||||
|
||||
return uint16(port)
|
||||
}
|
||||
|
|
@ -0,0 +1,306 @@
|
|||
package vmtempauth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
SigningKeySizeBytes = 32
|
||||
|
||||
JWTHeaderAlgHS256 = "HS256"
|
||||
JWTHeaderTypJWT = "JWT"
|
||||
|
||||
AccessTokenAudience = "orchard-vm-access"
|
||||
AccessTokenIssuer = "orchard-controller"
|
||||
|
||||
ScopeVMPortForward = "vm:port-forward"
|
||||
ScopeVMIP = "vm:ip"
|
||||
ScopeVMSSHJumpbox = "vm:ssh-jumpbox"
|
||||
|
||||
PortsAny = "*"
|
||||
|
||||
DefaultTTL = 24 * time.Hour
|
||||
MaxTTL = 30 * 24 * time.Hour
|
||||
|
||||
SSHUsername = "token"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidSigningKey = errors.New("invalid signing key")
|
||||
ErrMalformedToken = errors.New("malformed token")
|
||||
ErrInvalidTokenHeader = errors.New("invalid token header")
|
||||
ErrInvalidTokenClaims = errors.New("invalid token claims")
|
||||
ErrSignatureMismatch = errors.New("token signature mismatch")
|
||||
ErrTokenExpired = errors.New("token expired")
|
||||
ErrTokenNotYetValid = errors.New("token not yet valid")
|
||||
ErrInvalidTTL = errors.New("invalid access token TTL")
|
||||
encoding = base64.RawURLEncoding
|
||||
requiredScopes = []string{ScopeVMPortForward, ScopeVMIP, ScopeVMSSHJumpbox}
|
||||
)
|
||||
|
||||
type header struct {
|
||||
Alg string `json:"alg"`
|
||||
Typ string `json:"typ"`
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Audience []string `json:"aud"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
NotBefore int64 `json:"nbf"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
JTI string `json:"jti"`
|
||||
|
||||
VMUID string `json:"vm_uid"`
|
||||
VMName string `json:"vm_name,omitempty"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Ports string `json:"ports,omitempty"`
|
||||
}
|
||||
|
||||
func (claims Claims) HasScope(scope string) bool {
|
||||
return slices.Contains(claims.Scopes, scope)
|
||||
}
|
||||
|
||||
func (claims Claims) CanAccessVM(vmUID string) bool {
|
||||
return claims.VMUID != "" && vmUID != "" && claims.VMUID == vmUID
|
||||
}
|
||||
|
||||
type IssueInput struct {
|
||||
Issuer string
|
||||
Subject string
|
||||
VMUID string
|
||||
VMName string
|
||||
TTL time.Duration
|
||||
Now time.Time
|
||||
}
|
||||
|
||||
type IssueOutput struct {
|
||||
Token string
|
||||
Claims Claims
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func NewSigningKey() ([]byte, error) {
|
||||
key := make([]byte, SigningKeySizeBytes)
|
||||
|
||||
if _, err := rand.Read(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func NormalizeTTL(ttlSeconds *uint64) (time.Duration, error) {
|
||||
if ttlSeconds == nil {
|
||||
return DefaultTTL, nil
|
||||
}
|
||||
|
||||
if *ttlSeconds == 0 {
|
||||
return 0, fmt.Errorf("%w: TTL cannot be zero", ErrInvalidTTL)
|
||||
}
|
||||
|
||||
ttl := time.Duration(*ttlSeconds) * time.Second
|
||||
|
||||
if ttl > MaxTTL {
|
||||
return 0, fmt.Errorf("%w: maximum allowed TTL is %s", ErrInvalidTTL, MaxTTL)
|
||||
}
|
||||
|
||||
return ttl, nil
|
||||
}
|
||||
|
||||
func Issue(signingKey []byte, input IssueInput) (*IssueOutput, error) {
|
||||
if len(signingKey) != SigningKeySizeBytes {
|
||||
return nil, ErrInvalidSigningKey
|
||||
}
|
||||
|
||||
if input.Subject == "" {
|
||||
return nil, fmt.Errorf("%w: missing subject", ErrInvalidTokenClaims)
|
||||
}
|
||||
if input.VMUID == "" {
|
||||
return nil, fmt.Errorf("%w: missing vm_uid", ErrInvalidTokenClaims)
|
||||
}
|
||||
|
||||
ttl := input.TTL
|
||||
if ttl == 0 {
|
||||
ttl = DefaultTTL
|
||||
}
|
||||
if ttl < 0 || ttl > MaxTTL {
|
||||
return nil, fmt.Errorf("%w: unsupported TTL %s", ErrInvalidTTL, ttl)
|
||||
}
|
||||
|
||||
now := input.Now
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
issuer := input.Issuer
|
||||
if issuer == "" {
|
||||
issuer = AccessTokenIssuer
|
||||
}
|
||||
|
||||
claims := Claims{
|
||||
Issuer: issuer,
|
||||
Subject: input.Subject,
|
||||
Audience: []string{AccessTokenAudience},
|
||||
IssuedAt: now.Unix(),
|
||||
NotBefore: now.Unix(),
|
||||
ExpiresAt: now.Add(ttl).Unix(),
|
||||
JTI: uuid.NewString(),
|
||||
VMUID: input.VMUID,
|
||||
VMName: input.VMName,
|
||||
Scopes: append([]string{}, requiredScopes...),
|
||||
Ports: PortsAny,
|
||||
}
|
||||
|
||||
token, err := encodeAndSign(signingKey, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &IssueOutput{
|
||||
Token: token,
|
||||
Claims: claims,
|
||||
ExpiresAt: time.Unix(claims.ExpiresAt, 0).UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func Verify(signingKey []byte, token string, now time.Time) (*Claims, error) {
|
||||
if len(signingKey) != SigningKeySizeBytes {
|
||||
return nil, ErrInvalidSigningKey
|
||||
}
|
||||
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
splits := strings.Split(token, ".")
|
||||
if len(splits) != 3 {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
headerRaw, err := encoding.DecodeString(splits[0])
|
||||
if err != nil {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
var tokenHeader header
|
||||
|
||||
if err := json.Unmarshal(headerRaw, &tokenHeader); err != nil {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
if tokenHeader.Alg != JWTHeaderAlgHS256 || tokenHeader.Typ != JWTHeaderTypJWT {
|
||||
return nil, ErrInvalidTokenHeader
|
||||
}
|
||||
|
||||
signedPart := fmt.Sprintf("%s.%s", splits[0], splits[1])
|
||||
expectedSignature := sign(signingKey, signedPart)
|
||||
|
||||
presentedSignature, err := encoding.DecodeString(splits[2])
|
||||
if err != nil {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
if subtle.ConstantTimeCompare(expectedSignature, presentedSignature) != 1 {
|
||||
return nil, ErrSignatureMismatch
|
||||
}
|
||||
|
||||
claimsRaw, err := encoding.DecodeString(splits[1])
|
||||
if err != nil {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
var claims Claims
|
||||
|
||||
if err := json.Unmarshal(claimsRaw, &claims); err != nil {
|
||||
return nil, ErrMalformedToken
|
||||
}
|
||||
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
if err := validateClaims(claims, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
func encodeAndSign(signingKey []byte, claims Claims) (string, error) {
|
||||
headerJSON, err := json.Marshal(header{
|
||||
Alg: JWTHeaderAlgHS256,
|
||||
Typ: JWTHeaderTypJWT,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claimsJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedPart := fmt.Sprintf("%s.%s",
|
||||
encoding.EncodeToString(headerJSON),
|
||||
encoding.EncodeToString(claimsJSON),
|
||||
)
|
||||
|
||||
signature := sign(signingKey, signedPart)
|
||||
|
||||
return fmt.Sprintf("%s.%s", signedPart, encoding.EncodeToString(signature)), nil
|
||||
}
|
||||
|
||||
func sign(signingKey []byte, signedPart string) []byte {
|
||||
mac := hmac.New(sha256.New, signingKey)
|
||||
_, _ = mac.Write([]byte(signedPart))
|
||||
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func validateClaims(claims Claims, now time.Time) error {
|
||||
if claims.Issuer == "" || claims.Subject == "" || claims.JTI == "" || claims.VMUID == "" {
|
||||
return ErrInvalidTokenClaims
|
||||
}
|
||||
|
||||
if claims.IssuedAt == 0 || claims.NotBefore == 0 || claims.ExpiresAt == 0 {
|
||||
return ErrInvalidTokenClaims
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, AccessTokenAudience) {
|
||||
return ErrInvalidTokenClaims
|
||||
}
|
||||
|
||||
if claims.NotBefore > now.Unix() {
|
||||
return ErrTokenNotYetValid
|
||||
}
|
||||
if claims.ExpiresAt <= now.Unix() {
|
||||
return ErrTokenExpired
|
||||
}
|
||||
|
||||
for _, requiredScope := range requiredScopes {
|
||||
if !claims.HasScope(requiredScope) {
|
||||
return ErrInvalidTokenClaims
|
||||
}
|
||||
}
|
||||
|
||||
if claims.Ports != "" && claims.Ports != PortsAny {
|
||||
return ErrInvalidTokenClaims
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
package vmtempauth_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIssueAndVerify(t *testing.T) {
|
||||
signingKey, err := vmtempauth.NewSigningKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
now := time.Unix(1735779600, 0).UTC()
|
||||
|
||||
issued, err := vmtempauth.Issue(signingKey, vmtempauth.IssueInput{
|
||||
Subject: "issuer",
|
||||
VMUID: "vm-uid",
|
||||
VMName: "vm-name",
|
||||
TTL: 10 * time.Minute,
|
||||
Now: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := vmtempauth.Verify(signingKey, issued.Token, now.Add(time.Minute))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "issuer", claims.Subject)
|
||||
require.Equal(t, "vm-uid", claims.VMUID)
|
||||
require.Equal(t, "vm-name", claims.VMName)
|
||||
require.True(t, claims.HasScope(vmtempauth.ScopeVMPortForward))
|
||||
require.True(t, claims.HasScope(vmtempauth.ScopeVMIP))
|
||||
require.True(t, claims.HasScope(vmtempauth.ScopeVMSSHJumpbox))
|
||||
require.True(t, claims.CanAccessVM("vm-uid"))
|
||||
}
|
||||
|
||||
func TestVerifyExpired(t *testing.T) {
|
||||
signingKey, err := vmtempauth.NewSigningKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
now := time.Unix(1735779600, 0).UTC()
|
||||
issued, err := vmtempauth.Issue(signingKey, vmtempauth.IssueInput{
|
||||
Subject: "issuer",
|
||||
VMUID: "vm-uid",
|
||||
TTL: time.Second,
|
||||
Now: now,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = vmtempauth.Verify(signingKey, issued.Token, now.Add(2*time.Second))
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, vmtempauth.ErrTokenExpired)
|
||||
}
|
||||
|
||||
func TestVerifyBadSignature(t *testing.T) {
|
||||
signingKey, err := vmtempauth.NewSigningKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
issued, err := vmtempauth.Issue(signingKey, vmtempauth.IssueInput{
|
||||
Subject: "issuer",
|
||||
VMUID: "vm-uid",
|
||||
TTL: time.Minute,
|
||||
Now: time.Now().UTC(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tampered := issued.Token[:len(issued.Token)-1] + "x"
|
||||
|
||||
_, err = vmtempauth.Verify(signingKey, tampered, time.Now().UTC())
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, vmtempauth.ErrSignatureMismatch)
|
||||
}
|
||||
|
||||
func TestNormalizeTTL(t *testing.T) {
|
||||
defaultTTL, err := vmtempauth.NormalizeTTL(nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, vmtempauth.DefaultTTL, defaultTTL)
|
||||
|
||||
zero := uint64(0)
|
||||
_, err = vmtempauth.NormalizeTTL(&zero)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, vmtempauth.ErrInvalidTTL))
|
||||
|
||||
tooLong := uint64(vmtempauth.MaxTTL/time.Second) + 1
|
||||
_, err = vmtempauth.NormalizeTTL(&tooLong)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, vmtempauth.ErrInvalidTTL))
|
||||
}
|
||||
|
|
@ -42,6 +42,7 @@ type Client struct {
|
|||
|
||||
serviceAccountName string
|
||||
serviceAccountToken string
|
||||
bearerToken string
|
||||
|
||||
dialer dialer.Dialer
|
||||
}
|
||||
|
|
@ -321,6 +322,12 @@ func (client *Client) formatPath(path string) *url.URL {
|
|||
func (client *Client) modifyHeader(header http.Header) {
|
||||
header.Set("User-Agent", fmt.Sprintf("Orchard/%s", version.FullVersion))
|
||||
|
||||
if client.bearerToken != "" {
|
||||
header.Set("Authorization", fmt.Sprintf("Bearer %s", client.bearerToken))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if client.serviceAccountName != "" && client.serviceAccountToken != "" {
|
||||
authPlain := fmt.Sprintf("%s:%s", client.serviceAccountName, client.serviceAccountToken)
|
||||
authEncoded := base64.StdEncoding.EncodeToString([]byte(authPlain))
|
||||
|
|
|
|||
|
|
@ -30,6 +30,12 @@ func WithCredentials(serviceAccountName string, serviceAccountToken string) Opti
|
|||
}
|
||||
}
|
||||
|
||||
func WithBearerToken(token string) Option {
|
||||
return func(client *Client) {
|
||||
client.bearerToken = token
|
||||
}
|
||||
}
|
||||
|
||||
func WithDialer(dialer dialer.Dialer) Option {
|
||||
return func(client *Client) {
|
||||
client.dialer = dialer
|
||||
|
|
|
|||
|
|
@ -46,6 +46,10 @@ type EventsPageOptions struct {
|
|||
Cursor string
|
||||
}
|
||||
|
||||
type IssueAccessTokenOptions struct {
|
||||
TTLSeconds *uint64
|
||||
}
|
||||
|
||||
func (service *VMsService) Create(ctx context.Context, vm *v1.VM) error {
|
||||
err := service.client.request(ctx, http.MethodPost, "vms",
|
||||
vm, nil, nil)
|
||||
|
|
@ -172,6 +176,26 @@ func (service *VMsService) IP(ctx context.Context, name string, waitSeconds uint
|
|||
return result.IP, nil
|
||||
}
|
||||
|
||||
func (service *VMsService) IssueAccessToken(
|
||||
ctx context.Context,
|
||||
name string,
|
||||
options IssueAccessTokenOptions,
|
||||
) (*v1.VMAccessToken, error) {
|
||||
request := v1.IssueVMAccessTokenRequest{
|
||||
TTLSeconds: options.TTLSeconds,
|
||||
}
|
||||
|
||||
var token v1.VMAccessToken
|
||||
|
||||
err := service.client.request(ctx, http.MethodPost, fmt.Sprintf("vms/%s/access-tokens", url.PathEscape(name)),
|
||||
request, &token, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (service *VMsService) StreamEvents(name string) *EventStreamer {
|
||||
return NewEventStreamer(service.client, fmt.Sprintf("vms/%s/events", url.PathEscape(name)))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
package v1
|
||||
|
||||
import "time"
|
||||
|
||||
type IssueVMAccessTokenRequest struct {
|
||||
TTLSeconds *uint64 `json:"ttlSeconds,omitempty"`
|
||||
}
|
||||
|
||||
type VMAccessToken struct {
|
||||
Token string `json:"token,omitempty"`
|
||||
TokenType string `json:"tokenType,omitempty"`
|
||||
ExpiresAt time.Time `json:"expiresAt,omitempty"`
|
||||
SSHUsername string `json:"sshUsername,omitempty"`
|
||||
VMName string `json:"vmName,omitempty"`
|
||||
VMUID string `json:"vmUID,omitempty"`
|
||||
}
|
||||
Loading…
Reference in New Issue