347 lines
10 KiB
Go
347 lines
10 KiB
Go
package sshserver
|
|
|
|
import (
|
|
"context"
|
|
"crypto/subtle"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/cirruslabs/orchard/internal/controller/notifier"
|
|
"github.com/cirruslabs/orchard/internal/controller/rendezvous"
|
|
storepkg "github.com/cirruslabs/orchard/internal/controller/store"
|
|
"github.com/cirruslabs/orchard/internal/proxy"
|
|
"github.com/cirruslabs/orchard/internal/vmtempauth"
|
|
"github.com/cirruslabs/orchard/pkg/resource/v1"
|
|
"github.com/cirruslabs/orchard/rpc"
|
|
"github.com/google/uuid"
|
|
"github.com/samber/lo"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// Protocol-level identifiers used by the SSH server.
const (
	// channelTypeDirectTCPIP is the channel type that "ssh -J" opens;
	// it is documented in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1].
	//
	// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
	channelTypeDirectTCPIP = "direct-tcpip"

	// Keys of the ssh.Permissions extensions that carry the kind of
	// the authenticated principal and the UID of the VM it is bound to.
	permissionPrincipalTypeExt = "orchard-principal-type"
	permissionVMUIDExt         = "orchard-vm-uid"

	// principalTypeVMAccessToken is the value stored under
	// permissionPrincipalTypeExt for connections that were
	// authenticated using a VM access token.
	principalTypeVMAccessToken = "vm-access-token"
)
|
|
|
|
type SSHServer struct {
|
|
listener net.Listener
|
|
serverConfig *ssh.ServerConfig
|
|
store storepkg.Store
|
|
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]]
|
|
workerNotifier *notifier.Notifier
|
|
vmAccessTokenKey []byte
|
|
logger *zap.SugaredLogger
|
|
}
|
|
|
|
func NewSSHServer(
|
|
address string,
|
|
signer ssh.Signer,
|
|
store storepkg.Store,
|
|
connRendezvous *rendezvous.Rendezvous[rendezvous.ResultWithErrorMessage[net.Conn]],
|
|
workerNotifier *notifier.Notifier,
|
|
vmAccessTokenKey []byte,
|
|
noClientAuth bool,
|
|
logger *zap.SugaredLogger,
|
|
) (*SSHServer, error) {
|
|
server := &SSHServer{
|
|
store: store,
|
|
connRendezvous: connRendezvous,
|
|
workerNotifier: workerNotifier,
|
|
vmAccessTokenKey: vmAccessTokenKey,
|
|
logger: logger,
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
server.listener = listener
|
|
|
|
server.serverConfig = &ssh.ServerConfig{
|
|
NoClientAuth: noClientAuth,
|
|
PasswordCallback: server.passwordCallback,
|
|
}
|
|
server.serverConfig.AddHostKey(signer)
|
|
|
|
return server, nil
|
|
}
|
|
|
|
func (server *SSHServer) Run() {
|
|
for {
|
|
conn, err := server.listener.Accept()
|
|
if err != nil {
|
|
server.logger.Warnf("failed to accept connection: %v", err)
|
|
|
|
continue
|
|
}
|
|
|
|
go server.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
func (server *SSHServer) Address() string {
|
|
return strings.ReplaceAll(server.listener.Addr().String(), "[::]", "127.0.0.1")
|
|
}
|
|
|
|
func (server *SSHServer) passwordCallback(connMetadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
if connMetadata.User() == vmtempauth.SSHUsername {
|
|
claims, err := vmtempauth.Verify(server.vmAccessTokenKey, string(password), time.Now().UTC())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("authentication failed for user %q: invalid VM access token", connMetadata.User())
|
|
}
|
|
if !claims.HasScope(vmtempauth.ScopeVMSSHJumpbox) {
|
|
return nil, fmt.Errorf("authorization failed for user %q: token lacks %q scope",
|
|
connMetadata.User(), vmtempauth.ScopeVMSSHJumpbox)
|
|
}
|
|
|
|
return &ssh.Permissions{
|
|
Extensions: map[string]string{
|
|
permissionPrincipalTypeExt: principalTypeVMAccessToken,
|
|
permissionVMUIDExt: claims.VMUID,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
if err := server.store.View(func(txn storepkg.Transaction) error {
|
|
// Authenticate
|
|
server.logger.Debugf("authenticating user %q using the password authentication",
|
|
connMetadata.User())
|
|
|
|
serviceAccount, err := txn.GetServiceAccount(connMetadata.User())
|
|
if err != nil {
|
|
if errors.Is(err, storepkg.ErrNotFound) {
|
|
return fmt.Errorf("authentication failed, non-existent user %q",
|
|
connMetadata.User())
|
|
}
|
|
|
|
server.logger.Errorf("failed to retrieve service account %q: %v",
|
|
connMetadata.User(), err)
|
|
|
|
return fmt.Errorf("authentication failed due to an internal error")
|
|
}
|
|
|
|
if subtle.ConstantTimeCompare([]byte(serviceAccount.Token), password) != 1 {
|
|
return fmt.Errorf("authentication failed for user %q: invalid password",
|
|
connMetadata.User())
|
|
}
|
|
|
|
// Authorize
|
|
authorized := lo.Contains(serviceAccount.Roles, v1.ServiceAccountRoleComputeWrite) ||
|
|
lo.Contains(serviceAccount.Roles, v1.ServiceAccountRoleComputeConnect)
|
|
|
|
if !authorized {
|
|
return fmt.Errorf("authorization failed for user %q because it lacks %q or %q roles",
|
|
connMetadata.User(), v1.ServiceAccountRoleComputeWrite, v1.ServiceAccountRoleComputeConnect)
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &ssh.Permissions{}, nil
|
|
}
|
|
|
|
func (server *SSHServer) handleConnection(conn net.Conn) {
|
|
sshConn, newChannelCh, requestCh, err := ssh.NewServerConn(conn, server.serverConfig)
|
|
if err != nil {
|
|
server.logger.Warnf("failed to instantiate the SSH server instance to handle "+
|
|
"the incoming connection from %s: %v", conn.RemoteAddr().String(), err)
|
|
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = sshConn.Close()
|
|
}()
|
|
|
|
server.logger.Debugf("accepted SSH connection for user %q connecting from %q",
|
|
sshConn.User(), sshConn.RemoteAddr().String())
|
|
|
|
connCtx, connCtxCancel := context.WithCancel(context.Background())
|
|
defer connCtxCancel()
|
|
|
|
for {
|
|
select {
|
|
case newChannel, ok := <-newChannelCh:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
switch newChannel.ChannelType() {
|
|
case channelTypeDirectTCPIP:
|
|
server.logger.Debugf("handling a new direct TCP/IP channel for user %q connecting from %q",
|
|
sshConn.User(), sshConn.RemoteAddr().String())
|
|
|
|
go server.handleDirectTCPIP(connCtx, sshConn, newChannel)
|
|
default:
|
|
message := fmt.Sprintf("unsupported channel type requested: %q", newChannel.ChannelType())
|
|
|
|
server.logger.Debugf(message)
|
|
|
|
if err := newChannel.Reject(ssh.UnknownChannelType, message); err != nil {
|
|
server.logger.Warnf("failed to reject new channel of unsupported type %q: %v",
|
|
newChannel.ChannelType(), err)
|
|
|
|
return
|
|
}
|
|
}
|
|
case request, ok := <-requestCh:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
server.logger.Debugf("refusing to service new request of type %q with payload of %d bytes",
|
|
request.Type, len(request.Payload))
|
|
|
|
if err := request.Reply(false, nil); err != nil {
|
|
server.logger.Warnf("failed to reply to a new request of type %q and payload of %d bytes: %v",
|
|
request.Type, len(request.Payload), err)
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (server *SSHServer) handleDirectTCPIP(ctx context.Context, sshConn *ssh.ServerConn, newChannel ssh.NewChannel) {
|
|
// Unmarshal the payload to determine to which VM the user wants to connect to
|
|
//
|
|
// This direct TCP/IP channel's payload is documented
|
|
// in the RFC 4254 (7.2. TCP/IP Forwarding Channels)[1].
|
|
//
|
|
// [1]: https://datatracker.ietf.org/doc/html/rfc4254#section-7.2
|
|
payload := struct {
|
|
HostToConnect string
|
|
PortToConnect uint32
|
|
OriginatorIPAddress string
|
|
OriginatorPort uint32
|
|
}{}
|
|
|
|
if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil {
|
|
message := fmt.Sprintf("failed to unmarshal payload: %v", err)
|
|
|
|
server.logger.Warn(message)
|
|
|
|
if err := newChannel.Reject(ssh.ConnectionFailed, message); err != nil {
|
|
server.logger.Warnf("failed to reject the new channel: %v", err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
server.logger.Debugf("proxying connection to %s:%d", payload.HostToConnect, payload.PortToConnect)
|
|
|
|
// Retrieve the VM object
|
|
var vm *v1.VM
|
|
var err error
|
|
|
|
err = server.store.View(func(txn storepkg.Transaction) error {
|
|
vm, err = txn.GetVM(payload.HostToConnect)
|
|
|
|
return err
|
|
})
|
|
if err != nil {
|
|
if err := newChannel.Reject(ssh.ConnectionFailed, "failed to find VM"); err != nil {
|
|
server.logger.Warnf("failed to reject the new channel due to non-existent VM %q: %v",
|
|
payload.HostToConnect, err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
if sshConn.Permissions != nil &&
|
|
sshConn.Permissions.Extensions[permissionPrincipalTypeExt] == principalTypeVMAccessToken {
|
|
tokenVMUID := sshConn.Permissions.Extensions[permissionVMUIDExt]
|
|
if tokenVMUID == "" || tokenVMUID != vm.UID {
|
|
message := "authorization failed for requested VM"
|
|
|
|
if err := newChannel.Reject(ssh.Prohibited, message); err != nil {
|
|
server.logger.Warnf("failed to reject the new channel due to VM access token mismatch: %v", err)
|
|
}
|
|
|
|
return
|
|
}
|
|
}
|
|
|
|
// The user wants to connect to an existing VM, request and wait
|
|
// for a connection with the worker before accepting the channel
|
|
session := uuid.New().String()
|
|
boomerangConnCh, cancel := server.connRendezvous.Request(ctx, session)
|
|
defer cancel()
|
|
|
|
notifyContext, notifyContextCancel := context.WithTimeout(ctx, time.Second)
|
|
defer notifyContextCancel()
|
|
err = server.workerNotifier.Notify(notifyContext, vm.Worker, &rpc.WatchInstruction{
|
|
Action: &rpc.WatchInstruction_PortForwardAction{
|
|
PortForwardAction: &rpc.WatchInstruction_PortForward{
|
|
Session: session,
|
|
VmUid: vm.UID,
|
|
Port: payload.PortToConnect,
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
server.logger.Warnf("failed to request port-forwarding from the worker %s: %v",
|
|
vm.Worker, err)
|
|
|
|
return
|
|
}
|
|
|
|
// Wait for the connection from worker and commence port forwarding
|
|
select {
|
|
case rendezvousResponse := <-boomerangConnCh:
|
|
if rendezvousResponse.ErrorMessage != "" {
|
|
message := fmt.Sprintf("failed to establish port forwarding session on the worker: %s",
|
|
rendezvousResponse.ErrorMessage)
|
|
|
|
if err := newChannel.Reject(ssh.ConnectionFailed, message); err != nil {
|
|
server.logger.Warnf("failed to reject new channel due to "+
|
|
"failure establishing port forwarding session on the worker: %v", err)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Now that we have the connection from worker we can accept the channel
|
|
acceptedChannel, acceptedChannelRequests, err := newChannel.Accept()
|
|
if err != nil {
|
|
server.logger.Warnf("failed to accept the new channel: %v", err)
|
|
|
|
return
|
|
}
|
|
|
|
// Handle new requests on the accepted channel by refusing them
|
|
go func() {
|
|
req, ok := <-acceptedChannelRequests
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
if err := req.Reply(false, nil); err != nil {
|
|
server.logger.Warnf("failed to reply to the new channel request: %v", err)
|
|
|
|
return
|
|
}
|
|
}()
|
|
|
|
// Commence port forwarding
|
|
if err := proxy.Connections(acceptedChannel, rendezvousResponse.Result); err != nil {
|
|
server.logger.Warnf("failed to port-forward: %v", err)
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|