mirror of https://github.com/h44z/wg-portal.git
				
				
				
			
		
			
				
	
	
		
			243 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			243 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
| package wireguard
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/h44z/wg-portal/internal/app"
 | |
| 	"github.com/sirupsen/logrus"
 | |
| 
 | |
| 	evbus "github.com/vardius/message-bus"
 | |
| 
 | |
| 	"github.com/h44z/wg-portal/internal/config"
 | |
| 	"github.com/h44z/wg-portal/internal/domain"
 | |
| )
 | |
| 
 | |
| type Manager struct {
 | |
| 	cfg *config.Config
 | |
| 	bus evbus.MessageBus
 | |
| 
 | |
| 	db    InterfaceAndPeerDatabaseRepo
 | |
| 	wg    InterfaceController
 | |
| 	quick WgQuickController
 | |
| }
 | |
| 
 | |
| func NewWireGuardManager(
 | |
| 	cfg *config.Config,
 | |
| 	bus evbus.MessageBus,
 | |
| 	wg InterfaceController,
 | |
| 	quick WgQuickController,
 | |
| 	db InterfaceAndPeerDatabaseRepo,
 | |
| ) (*Manager, error) {
 | |
| 	m := &Manager{
 | |
| 		cfg:   cfg,
 | |
| 		bus:   bus,
 | |
| 		wg:    wg,
 | |
| 		db:    db,
 | |
| 		quick: quick,
 | |
| 	}
 | |
| 
 | |
| 	m.connectToMessageBus()
 | |
| 
 | |
| 	return m, nil
 | |
| }
 | |
| 
 | |
| func (m Manager) StartBackgroundJobs(ctx context.Context) {
 | |
| 	go m.runExpiredPeersCheck(ctx)
 | |
| }
 | |
| 
 | |
| func (m Manager) connectToMessageBus() {
 | |
| 	_ = m.bus.Subscribe(app.TopicUserCreated, m.handleUserCreationEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicAuthLogin, m.handleUserLoginEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicUserDisabled, m.handleUserDisabledEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicUserEnabled, m.handleUserEnabledEvent)
 | |
| 	_ = m.bus.Subscribe(app.TopicUserDeleted, m.handleUserDeletedEvent)
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserCreationEvent(user *domain.User) {
 | |
| 	if !m.cfg.Core.CreateDefaultPeerOnCreation {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	logrus.Tracef("handling new user event for %s", user.Identifier)
 | |
| 
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	err := m.CreateDefaultPeer(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		logrus.Errorf("failed to create default peer for %s: %v", user.Identifier, err)
 | |
| 		return
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) {
 | |
| 	if !m.cfg.Core.CreateDefaultPeer {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	userPeers, err := m.db.GetUserPeers(context.Background(), userId)
 | |
| 	if err != nil {
 | |
| 		logrus.Errorf("failed to retrieve existing peers for %s prior to default peer creation: %v", userId, err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if len(userPeers) > 0 {
 | |
| 		return // user already has peers, skip creation
 | |
| 	}
 | |
| 
 | |
| 	logrus.Tracef("handling new user login for %s", userId)
 | |
| 
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	err = m.CreateDefaultPeer(ctx, userId)
 | |
| 	if err != nil {
 | |
| 		logrus.Errorf("failed to create default peer for %s: %v", userId, err)
 | |
| 		return
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserDisabledEvent(user domain.User) {
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		logrus.Errorf("failed to retrieve peers for disabled user %s: %v", user.Identifier, err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for _, peer := range userPeers {
 | |
| 		if peer.IsDisabled() {
 | |
| 			continue // peer is already disabled
 | |
| 		}
 | |
| 
 | |
| 		logrus.Debugf("disabling peer %s due to user %s being disabled", peer.Identifier, user.Identifier)
 | |
| 
 | |
| 		peer.Disabled = user.Disabled // set to user disabled timestamp
 | |
| 		peer.DisabledReason = domain.DisabledReasonUserDisabled
 | |
| 
 | |
| 		_, err := m.UpdatePeer(ctx, &peer)
 | |
| 		if err != nil {
 | |
| 			logrus.Errorf("failed to disable peer %s for disabled user %s: %v",
 | |
| 				peer.Identifier, user.Identifier, err)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserEnabledEvent(user domain.User) {
 | |
| 	if !m.cfg.Core.ReEnablePeerAfterUserEnable {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		logrus.Errorf("failed to retrieve peers for re-enabled user %s: %v", user.Identifier, err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for _, peer := range userPeers {
 | |
| 		if !peer.IsDisabled() {
 | |
| 			continue // peer is already active
 | |
| 		}
 | |
| 
 | |
| 		if peer.DisabledReason != domain.DisabledReasonUserDisabled {
 | |
| 			continue // peer was disabled for another reason
 | |
| 		}
 | |
| 
 | |
| 		logrus.Debugf("enabling peer %s due to user %s being enabled", peer.Identifier, user.Identifier)
 | |
| 
 | |
| 		peer.Disabled = nil
 | |
| 		peer.DisabledReason = ""
 | |
| 
 | |
| 		_, err := m.UpdatePeer(ctx, &peer)
 | |
| 		if err != nil {
 | |
| 			logrus.Errorf("failed to enable peer %s for enabled user %s: %v",
 | |
| 				peer.Identifier, user.Identifier, err)
 | |
| 		}
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (m Manager) handleUserDeletedEvent(user domain.User) {
 | |
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo())
 | |
| 	userPeers, err := m.db.GetUserPeers(ctx, user.Identifier)
 | |
| 	if err != nil {
 | |
| 		logrus.Errorf("failed to retrieve peers for deleted user %s: %v", user.Identifier, err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	deletionTime := time.Now()
 | |
| 	for _, peer := range userPeers {
 | |
| 		if peer.IsDisabled() {
 | |
| 			continue // peer is already disabled
 | |
| 		}
 | |
| 
 | |
| 		if m.cfg.Core.DeletePeerAfterUserDeleted {
 | |
| 			logrus.Debugf("deleting peer %s due to user %s being deleted", peer.Identifier, user.Identifier)
 | |
| 
 | |
| 			if err := m.DeletePeer(ctx, peer.Identifier); err != nil {
 | |
| 				logrus.Errorf("failed to delete peer %s for deleted user %s: %v",
 | |
| 					peer.Identifier, user.Identifier, err)
 | |
| 			}
 | |
| 		} else {
 | |
| 			logrus.Debugf("disabling peer %s due to user %s being deleted", peer.Identifier, user.Identifier)
 | |
| 
 | |
| 			peer.UserIdentifier = "" // remove user reference
 | |
| 			peer.Disabled = &deletionTime
 | |
| 			peer.DisabledReason = domain.DisabledReasonUserDeleted
 | |
| 
 | |
| 			_, err := m.UpdatePeer(ctx, &peer)
 | |
| 			if err != nil {
 | |
| 				logrus.Errorf("failed to disable peer %s for deleted user %s: %v",
 | |
| 					peer.Identifier, user.Identifier, err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) runExpiredPeersCheck(ctx context.Context) {
 | |
| 	ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo())
 | |
| 
 | |
| 	running := true
 | |
| 	for running {
 | |
| 		select {
 | |
| 		case <-ctx.Done():
 | |
| 			running = false
 | |
| 			continue
 | |
| 		case <-time.After(m.cfg.Advanced.ExpiryCheckInterval):
 | |
| 			// select blocks until one of the cases evaluate to true
 | |
| 		}
 | |
| 
 | |
| 		interfaces, err := m.db.GetAllInterfaces(ctx)
 | |
| 		if err != nil {
 | |
| 			logrus.Errorf("failed to fetch all interfaces for expiry check: %v", err)
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		for _, iface := range interfaces {
 | |
| 			peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
 | |
| 			if err != nil {
 | |
| 				logrus.Errorf("failed to fetch all peers from interface %s for expiry check: %v", iface.Identifier, err)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			m.checkExpiredPeers(ctx, peers)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (m Manager) checkExpiredPeers(ctx context.Context, peers []domain.Peer) {
 | |
| 	now := time.Now()
 | |
| 
 | |
| 	for _, peer := range peers {
 | |
| 		if peer.IsExpired() && !peer.IsDisabled() {
 | |
| 			logrus.Infof("peer %s has expired, disabling...", peer.Identifier)
 | |
| 
 | |
| 			peer.Disabled = &now
 | |
| 			peer.DisabledReason = domain.DisabledReasonExpired
 | |
| 
 | |
| 			_, err := m.UpdatePeer(ctx, &peer)
 | |
| 			if err != nil {
 | |
| 				logrus.Errorf("failed to update expired peer %s: %v", peer.Identifier, err)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 |