mirror of https://github.com/h44z/wg-portal.git
				
				
				
			
		
			
				
	
	
		
			327 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			327 lines
		
	
	
		
			9.2 KiB
		
	
	
	
		
			Go
		
	
	
	
| package wireguard
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"log/slog"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/h44z/wg-portal/internal/app"
 | |
| 	"github.com/h44z/wg-portal/internal/config"
 | |
| 	"github.com/h44z/wg-portal/internal/domain"
 | |
| )
 | |
| 
 | |
| // region dependencies
 | |
| 
 | |
| type InterfaceAndPeerDatabaseRepo interface {
 | |
| 	GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
 | |
| 	GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error)
 | |
| 	GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error)
 | |
| 	GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
 | |
| 	GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error)
 | |
| 	SaveInterface(
 | |
| 		ctx context.Context,
 | |
| 		id domain.InterfaceIdentifier,
 | |
| 		updateFunc func(in *domain.Interface) (*domain.Interface, error),
 | |
| 	) error
 | |
| 	DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error
 | |
| 	GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
 | |
| 	GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
 | |
| 	SavePeer(
 | |
| 		ctx context.Context,
 | |
| 		id domain.PeerIdentifier,
 | |
| 		updateFunc func(in *domain.Peer) (*domain.Peer, error),
 | |
| 	) error
 | |
| 	DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
 | |
| 	GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error)
 | |
| 	GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error)
 | |
| }
 | |
| 
 | |
| type InterfaceController interface {
 | |
| 	GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
 | |
| 	GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
 | |
| 	GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
 | |
| 	SaveInterface(
 | |
| 		_ context.Context,
 | |
| 		id domain.InterfaceIdentifier,
 | |
| 		updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
 | |
| 	) error
 | |
| 	DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
 | |
| 	SavePeer(
 | |
| 		_ context.Context,
 | |
| 		deviceId domain.InterfaceIdentifier,
 | |
| 		id domain.PeerIdentifier,
 | |
| 		updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
 | |
| 	) error
 | |
| 	DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
 | |
| }
 | |
| 
 | |
| type WgQuickController interface {
 | |
| 	ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
 | |
| 	SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
 | |
| 	UnsetDNS(id domain.InterfaceIdentifier) error
 | |
| }
 | |
| 
 | |
// EventBus is the publish/subscribe dependency through which the manager
// receives user lifecycle and authentication events.
type EventBus interface {
	// Publish sends a message to the message bus.
	Publish(topic string, args ...any)
	// Subscribe registers fn as a handler for topic. fn is an arbitrary
	// function value; the manager passes handlers such as func(domain.User).
	// NOTE(review): presumably the bus matches published args to fn's
	// parameters. Confirm with the implementation.
	//
	// Declared with `any` (an alias of interface{}) for consistency with
	// Publish; existing implementations still satisfy the interface.
	Subscribe(topic string, fn any) error
}
 | |
| 
 | |
| // endregion dependencies
 | |
| 
 | |
| type Manager struct {
 | |
| 	cfg   *config.Config
 | |
| 	bus   EventBus
 | |
| 	db    InterfaceAndPeerDatabaseRepo
 | |
| 	wg    InterfaceController
 | |
| 	quick WgQuickController
 | |
| }
 | |
| 
 | |
| func NewWireGuardManager(
 | |
| 	cfg *config.Config,
 | |
| 	bus EventBus,
 | |
| 	wg InterfaceController,
 | |
| 	quick WgQuickController,
 | |
| 	db InterfaceAndPeerDatabaseRepo,
 | |
| ) (*Manager, error) {
 | |
| 	m := &Manager{
 | |
| 		cfg:   cfg,
 | |
| 		bus:   bus,
 | |
| 		wg:    wg,
 | |
| 		db:    db,
 | |
| 		quick: quick,
 | |
| 	}
 | |
| 
 | |
| 	m.connectToMessageBus()
 | |
| 
 | |
| 	return m, nil
 | |
| }
 | |
| 
 | |
| // StartBackgroundJobs starts background jobs like the expired peers check.
 | |
| // This method is non-blocking.
 | |
| func (m Manager) StartBackgroundJobs(ctx context.Context) {
 | |
| 	go m.runExpiredPeersCheck(ctx)
 | |
| }
 | |
| 
 | |
| func (m Manager) connectToMessageBus() {
 | |
| 	_ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicAuthLogin, m.handleUserLoginEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicUserDisabled, m.handleUserDisabledEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicUserEnabled, m.handleUserEnabledEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicUserDeleted, m.handleUserDeletedEvent)
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserCreationEvent(user domain.User) {
 | |
| 	if !m.cfg.Core.CreateDefaultPeerOnCreation {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	slog.Debug("handling new user event", "user", user.Identifier)
 | |
| 
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	err := m.CreateDefaultPeer(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		slog.Error("failed to create default peer", "user", user.Identifier, "error", err)
 | |
| 		return
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) {
 | |
| 	if !m.cfg.Core.CreateDefaultPeer {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	userPeers, err := m.db.GetUserPeers(context.Background(), userId)
 | |
| 	if err != nil {
 | |
| 		slog.Error("failed to retrieve existing peers prior to default peer creation",
 | |
| 			"user", userId,
 | |
| 			"error", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if len(userPeers) > 0 {
 | |
| 		return // user already has peers, skip creation
 | |
| 	}
 | |
| 
 | |
| 	slog.Debug("handling new user login", "user", userId)
 | |
| 
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	err = m.CreateDefaultPeer(ctx, userId)
 | |
| 	if err != nil {
 | |
| 		slog.Error("failed to create default peer", "user", userId, "error", err)
 | |
| 		return
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserDisabledEvent(user domain.User) {
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		slog.Error("failed to retrieve peers for disabled user",
 | |
| 			"user", user.Identifier,
 | |
| 			"error", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for _, peer := range userPeers {
 | |
| 		if peer.IsDisabled() {
 | |
| 			continue // peer is already disabled
 | |
| 		}
 | |
| 
 | |
| 		slog.Debug("disabling peer due to user being disabled",
 | |
| 			"peer", peer.Identifier,
 | |
| 			"user", user.Identifier)
 | |
| 
 | |
| 		peer.Disabled = user.Disabled // set to user disabled timestamp
 | |
| 		peer.DisabledReason = domain.DisabledReasonUserDisabled
 | |
| 
 | |
| 		_, err := m.UpdatePeer(ctx, &peer)
 | |
| 		if err != nil {
 | |
| 			slog.Error("failed to disable peer for disabled user",
 | |
| 				"peer", peer.Identifier,
 | |
| 				"user", user.Identifier,
 | |
| 				"error", err)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserEnabledEvent(user domain.User) {
 | |
| 	if !m.cfg.Core.ReEnablePeerAfterUserEnable {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		slog.Error("failed to retrieve peers for re-enabled user",
 | |
| 			"user", user.Identifier,
 | |
| 			"error", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for _, peer := range userPeers {
 | |
| 		if !peer.IsDisabled() {
 | |
| 			continue // peer is already active
 | |
| 		}
 | |
| 
 | |
| 		if peer.DisabledReason != domain.DisabledReasonUserDisabled {
 | |
| 			continue // peer was disabled for another reason
 | |
| 		}
 | |
| 
 | |
| 		slog.Debug("enabling peer due to user being enabled",
 | |
| 			"peer", peer.Identifier,
 | |
| 			"user", user.Identifier)
 | |
| 
 | |
| 		peer.Disabled = nil
 | |
| 		peer.DisabledReason = ""
 | |
| 
 | |
| 		_, err := m.UpdatePeer(ctx, &peer)
 | |
| 		if err != nil {
 | |
| 			slog.Error("failed to enable peer for enabled user",
 | |
| 				"peer", peer.Identifier,
 | |
| 				"user", user.Identifier,
 | |
| 				"error", err)
 | |
| 		}
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserDeletedEvent(user domain.User) {
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		slog.Error("failed to retrieve peers for deleted user",
 | |
| 			"user", user.Identifier,
 | |
| 			"error", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	deletionTime := time.Now()
 | |
| 	for _, peer := range userPeers {
 | |
| 		if peer.IsDisabled() {
 | |
| 			continue // peer is already disabled
 | |
| 		}
 | |
| 
 | |
| 		if m.cfg.Core.DeletePeerAfterUserDeleted {
 | |
| 			slog.Debug("deleting peer due to user being deleted",
 | |
| 				"peer", peer.Identifier,
 | |
| 				"user", user.Identifier)
 | |
| 
 | |
| 			if err := m.DeletePeer(ctx, peer.Identifier); err != nil {
 | |
| 				slog.Error("failed to delete peer for deleted user",
 | |
| 					"peer", peer.Identifier,
 | |
| 					"user", user.Identifier,
 | |
| 					"error", err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			slog.Debug("disabling peer due to user being deleted",
 | |
| 				"peer", peer.Identifier,
 | |
| 				"user", user.Identifier)
 | |
| 
 | |
| 			peer.UserIdentifier = "" // remove user reference
 | |
| 			peer.Disabled = &deletionTime
 | |
| 			peer.DisabledReason = domain.DisabledReasonUserDeleted
 | |
| 
 | |
| 			_, err := m.UpdatePeer(ctx, &peer)
 | |
| 			if err != nil {
 | |
| 				slog.Error("failed to disable peer for deleted user",
 | |
| 					"peer", peer.Identifier,
 | |
| 					"user", user.Identifier,
 | |
| 					"error", err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) runExpiredPeersCheck(ctx context.Context) {
 | |
| 	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
 | |
| 
 | |
| 	running := true
 | |
| 	for running {
 | |
| 		select {
 | |
| 		case <-ctx.Done():
 | |
| 			running = false
 | |
| 			continue
 | |
| 		case <-time.After(m.cfg.Advanced.ExpiryCheckInterval):
 | |
| 			// select blocks until one of the cases evaluate to true
 | |
| 		}
 | |
| 
 | |
| 		interfaces, err := m.db.GetAllInterfaces(ctx)
 | |
| 		if err != nil {
 | |
| 			slog.Error("failed to fetch all interfaces for expiry check", "error", err)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		for _, iface := range interfaces {
 | |
| 			peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
 | |
| 			if err != nil {
 | |
| 				slog.Error("failed to fetch all peers from interface for expiry check",
 | |
| 					"interface", iface.Identifier,
 | |
| 					"error", err)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			m.checkExpiredPeers(ctx, peers)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) checkExpiredPeers(ctx context.Context, peers []domain.Peer) {
 | |
| 	now := time.Now()
 | |
| 
 | |
| 	for _, peer := range peers {
 | |
| 		if peer.IsExpired() && !peer.IsDisabled() {
 | |
| 			slog.Info("peer has expired, disabling", "peer", peer.Identifier)
 | |
| 
 | |
| 			peer.Disabled = &now
 | |
| 			peer.DisabledReason = domain.DisabledReasonExpired
 | |
| 
 | |
| 			_, err := m.UpdatePeer(ctx, &peer)
 | |
| 			if err != nil {
 | |
| 				slog.Error("failed to update expired peer", "peer", peer.Identifier, "error", err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 |