mirror of https://github.com/h44z/wg-portal.git
				
				
				
			fix duplicate creation of default peer (#437)
This commit is contained in:
		
							parent
							
								
									ab9995350f
								
							
						
					
					
						commit
						8816165260
					
				|  | @ -3,6 +3,7 @@ package wireguard | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"log/slog" | 	"log/slog" | ||||||
|  | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/h44z/wg-portal/internal/app" | 	"github.com/h44z/wg-portal/internal/app" | ||||||
|  | @ -76,6 +77,8 @@ type Manager struct { | ||||||
| 	db    InterfaceAndPeerDatabaseRepo | 	db    InterfaceAndPeerDatabaseRepo | ||||||
| 	wg    InterfaceController | 	wg    InterfaceController | ||||||
| 	quick WgQuickController | 	quick WgQuickController | ||||||
|  | 
 | ||||||
|  | 	userLockMap *sync.Map | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewWireGuardManager( | func NewWireGuardManager( | ||||||
|  | @ -91,6 +94,7 @@ func NewWireGuardManager( | ||||||
| 		wg:          wg, | 		wg:          wg, | ||||||
| 		db:          db, | 		db:          db, | ||||||
| 		quick:       quick, | 		quick:       quick, | ||||||
|  | 		userLockMap: &sync.Map{}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	m.connectToMessageBus() | 	m.connectToMessageBus() | ||||||
|  | @ -117,6 +121,12 @@ func (m Manager) handleUserCreationEvent(user domain.User) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	_, loaded := m.userLockMap.LoadOrStore(user.Identifier, "create") | ||||||
|  | 	if loaded { | ||||||
|  | 		return // another goroutine is already handling this user
 | ||||||
|  | 	} | ||||||
|  | 	defer m.userLockMap.Delete(user.Identifier) | ||||||
|  | 
 | ||||||
| 	slog.Debug("handling new user event", "user", user.Identifier) | 	slog.Debug("handling new user event", "user", user.Identifier) | ||||||
| 
 | 
 | ||||||
| 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) | 	ctx := domain.SetUserInfo(context.Background(), domain.SystemAdminContextUserInfo()) | ||||||
|  | @ -132,6 +142,12 @@ func (m Manager) handleUserLoginEvent(userId domain.UserIdentifier) { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	_, loaded := m.userLockMap.LoadOrStore(userId, "login") | ||||||
|  | 	if loaded { | ||||||
|  | 		return // another goroutine is already handling this user
 | ||||||
|  | 	} | ||||||
|  | 	defer m.userLockMap.Delete(userId) | ||||||
|  | 
 | ||||||
| 	userPeers, err := m.db.GetUserPeers(context.Background(), userId) | 	userPeers, err := m.db.GetUserPeers(context.Background(), userId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		slog.Error("failed to retrieve existing peers prior to default peer creation", | 		slog.Error("failed to retrieve existing peers prior to default peer creation", | ||||||
|  |  | ||||||
|  | @ -5,6 +5,7 @@ import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"log/slog" | 	"log/slog" | ||||||
|  | 	"slices" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/h44z/wg-portal/internal/app" | 	"github.com/h44z/wg-portal/internal/app" | ||||||
|  | @ -23,12 +24,24 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti | ||||||
| 		return fmt.Errorf("failed to fetch all interfaces: %w", err) | 		return fmt.Errorf("failed to fetch all interfaces: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	userPeers, err := m.db.GetUserPeers(context.Background(), userId) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("failed to retrieve existing peers prior to default peer creation: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	var newPeers []domain.Peer | 	var newPeers []domain.Peer | ||||||
| 	for _, iface := range existingInterfaces { | 	for _, iface := range existingInterfaces { | ||||||
| 		if iface.Type != domain.InterfaceTypeServer { | 		if iface.Type != domain.InterfaceTypeServer { | ||||||
| 			continue // only create default peers for server interfaces
 | 			continue // only create default peers for server interfaces
 | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		peerAlreadyCreated := slices.ContainsFunc(userPeers, func(peer domain.Peer) bool { | ||||||
|  | 			return peer.InterfaceIdentifier == iface.Identifier | ||||||
|  | 		}) | ||||||
|  | 		if peerAlreadyCreated { | ||||||
|  | 			continue // skip creation if a peer already exists for this interface
 | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		peer, err := m.PreparePeer(ctx, iface.Identifier) | 		peer, err := m.PreparePeer(ctx, iface.Identifier) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err) | 			return fmt.Errorf("failed to create default peer for interface %s: %w", iface.Identifier, err) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue