package adapters import ( "context" "errors" "fmt" "log/slog" "math/rand" "strings" "time" "github.com/glebarez/sqlite" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlserver" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/utils" "github.com/fedor-git/wg-portal-2/internal/config" "github.com/fedor-git/wg-portal-2/internal/domain" ) // SchemaVersion describes the current database schema version. It must be incremented if a manual migration is needed. var SchemaVersion uint64 = 1 // SysStat stores the current database schema version and the timestamp when it was applied. type SysStat struct { MigratedAt time.Time `gorm:"column:migrated_at"` SchemaVersion uint64 `gorm:"primaryKey,column:schema_version"` } // NodeSyncLock provides distributed locking for peer synchronization across nodes. // Only one node can hold the lock at a time, preventing concurrent syncs that cause // database contention and 504 Gateway Timeout errors. type NodeSyncLock struct { LockKey string `gorm:"primaryKey;column:lock_key"` NodeID string `gorm:"column:node_id;index"` LockedAt time.Time `gorm:"column:locked_at"` ExpiresAt time.Time `gorm:"column:expires_at;index"` // Auto-release stuck locks after 5 minutes } func (NodeSyncLock) TableName() string { return "node_sync_locks" } // GormLogger is a custom logger for Gorm, making it use slog type GormLogger struct { SlowThreshold time.Duration SourceField string IgnoreErrRecordNotFound bool Debug bool Silent bool prefix string } func NewLogger(slowThreshold time.Duration, debug bool) *GormLogger { return &GormLogger{ SlowThreshold: slowThreshold, Debug: debug, IgnoreErrRecordNotFound: true, Silent: false, SourceField: "src", prefix: "GORM-SQL: ", } } func (l *GormLogger) LogMode(level logger.LogLevel) logger.Interface { if level == logger.Silent { l.Silent = true } else { l.Silent = false } return l } func (l *GormLogger) Info(ctx context.Context, s string, args ...any) { if l.Silent { return } 
slog.InfoContext(ctx, l.prefix+s, args...) } func (l *GormLogger) Warn(ctx context.Context, s string, args ...any) { if l.Silent { return } slog.WarnContext(ctx, l.prefix+s, args...) } func (l *GormLogger) Error(ctx context.Context, s string, args ...any) { if l.Silent { return } slog.ErrorContext(ctx, l.prefix+s, args...) } func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.Silent { return } elapsed := time.Since(begin) sql, rows := fc() attrs := []any{ "rows", rows, "duration", elapsed, } if l.SourceField != "" { attrs = append(attrs, l.SourceField, utils.FileWithLineNum()) } if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.IgnoreErrRecordNotFound) { attrs = append(attrs, "error", err) slog.ErrorContext(ctx, l.prefix+sql, attrs...) return } if l.SlowThreshold != 0 && elapsed > l.SlowThreshold { slog.WarnContext(ctx, l.prefix+sql, attrs...) return } if l.Debug { slog.DebugContext(ctx, l.prefix+sql, attrs...) } } // applyConnectionPoolConfig applies the connection pool configuration from config to the database. 
func applyConnectionPoolConfig(db *gorm.DB, cfg config.DatabaseConfig) error { sqlDB, err := db.DB() if err != nil { return fmt.Errorf("failed to get underlying sql.DB: %w", err) } // Use configured values, or defaults if not specified maxOpen := cfg.MaxOpenConnections if maxOpen <= 0 { maxOpen = 50 } maxIdle := cfg.MaxIdleConnections if maxIdle <= 0 { maxIdle = 10 } maxLifetime := cfg.ConnectionMaxLifetime if maxLifetime <= 0 { maxLifetime = 3 * time.Minute } // For multi-node cluster with 24 nodes, ensure adequate pool size // Each node may need 2-3 concurrent connections for: // - Interface/peer sync operations // - Message bus event processing // - Regular API requests // Recommended: maxOpen = (nodes * 2-3) + buffer, e.g., 24*2.5 = 60 if maxOpen < 60 { slog.Warn("Connection pool size may be too small for multi-node cluster", "configured", maxOpen, "recommended_minimum", 60) } if maxIdle < 10 { maxIdle = 10 // Ensure minimum idle connections for cluster } sqlDB.SetMaxOpenConns(maxOpen) sqlDB.SetMaxIdleConns(maxIdle) sqlDB.SetConnMaxLifetime(maxLifetime) // Set connection max idle time to recycle stale connections faster // Prevents "too many connections" by recycling idle connections after 5 minutes sqlDB.SetConnMaxIdleTime(5 * time.Minute) slog.Info("Database connection pool configured", "max_open_connections", maxOpen, "max_idle_connections", maxIdle, "connection_max_lifetime", maxLifetime.String(), "connection_max_idle_time", "5m") return nil } // NewDatabase creates a new database connection and returns a Gorm database instance. 
func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) { var gormDb *gorm.DB var err error // Retry logic for initial database connection // With 24 nodes starting simultaneously, initial connection can fail with "Too many connections" // Increased to 50 retries to handle all 24 nodes with exponential backoff and jitter // This gives us up to ~2 minutes of retry time with proper staggering const maxRetries = 50 var lastErr error for attempt := 0; attempt < maxRetries; attempt++ { if attempt > 0 { // Exponential backoff with random jitter to prevent thundering herd // With 24 nodes starting simultaneously, stagger retries much more aggressively // Exponential: 50ms, 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, 6.4s... // Jitter: ±100% randomness so not all nodes retry at same time baseWaitTime := time.Duration(50*(1< 10*time.Second { baseWaitTime = 10 * time.Second } // Full randomness jitter to spread retries across time jitterAmount := time.Duration(rand.Intn(int(baseWaitTime))) // ±100% jitter waitTime := baseWaitTime + jitterAmount slog.Warn("database connection attempt failed, retrying with exponential backoff", "attempt", attempt+1, "max_retries", maxRetries, "base_wait", baseWaitTime.String(), "actual_wait", waitTime.String(), "error", lastErr) time.Sleep(waitTime) } switch cfg.Type { case config.DatabaseMySQL: gormDb, err = gorm.Open(mysql.Open(cfg.DSN), &gorm.Config{ Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug), }) if err != nil { lastErr = fmt.Errorf("failed to open MySQL database: %w", err) if strings.Contains(err.Error(), "Too many connections") { continue // Retry on connection pool exhaustion } return nil, lastErr } // Apply connection pool configuration if err := applyConnectionPoolConfig(gormDb, cfg); err != nil { lastErr = fmt.Errorf("failed to configure MySQL connection pool: %w", err) continue } // Skip Ping() during initial connection to avoid connection pool exhaustion // during multi-node startup. 
gorm.Open() already validates the connection. // Ping would open another connection which might fail with "Too many connections" slog.Info("database connected successfully", "type", "MySQL", "attempt", attempt+1) return gormDb, nil case config.DatabaseMsSQL: gormDb, err = gorm.Open(sqlserver.Open(cfg.DSN), &gorm.Config{ Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug), }) if err != nil { lastErr = fmt.Errorf("failed to open sqlserver database: %w", err) if strings.Contains(err.Error(), "Too many connections") { continue } return nil, lastErr } // Apply connection pool configuration if err := applyConnectionPoolConfig(gormDb, cfg); err != nil { lastErr = fmt.Errorf("failed to configure MSSQL connection pool: %w", err) continue } slog.Info("database connected successfully", "type", "MSSQL", "attempt", attempt+1) return gormDb, nil case config.DatabasePostgres: gormDb, err = gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{ Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug), }) if err != nil { lastErr = fmt.Errorf("failed to open Postgres database: %w", err) if strings.Contains(err.Error(), "too many connections") { continue } return nil, lastErr } // Apply connection pool configuration if err := applyConnectionPoolConfig(gormDb, cfg); err != nil { lastErr = fmt.Errorf("failed to configure Postgres connection pool: %w", err) continue } slog.Info("database connected successfully", "type", "Postgres", "attempt", attempt+1) return gormDb, nil case config.DatabaseSQLite: gormDb, err = gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{ Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug), }) if err != nil { lastErr = fmt.Errorf("failed to open SQLite database: %w", err) continue } // SQLite doesn't benefit from connection pooling, set to 1 if err := applyConnectionPoolConfig(gormDb, config.DatabaseConfig{ MaxOpenConnections: 1, MaxIdleConnections: 1, ConnectionMaxLifetime: 0, }); err != nil { lastErr = fmt.Errorf("failed to configure SQLite connection pool: %w", err) 
continue } slog.Info("database connected successfully", "type", "SQLite", "attempt", attempt+1) return gormDb, nil default: return nil, fmt.Errorf("unknown database type: %s", cfg.Type) } } return nil, fmt.Errorf("failed to connect to database after %d attempts: %w", maxRetries, lastErr) } // SqlRepo is a SQL database repository implementation. // Currently, it supports MySQL, SQLite, Microsoft SQL and Postgresql database systems. type SqlRepo struct { db *gorm.DB cfg *config.Config } // NewSqlRepository creates a new SqlRepo instance. func NewSqlRepository(db *gorm.DB, cfg *config.Config) (*SqlRepo, error) { repo := &SqlRepo{ db: db, cfg: cfg, } if err := repo.preCheck(); err != nil { return nil, fmt.Errorf("failed to initialize database: %w", err) } if err := repo.migrate(); err != nil { return nil, fmt.Errorf("failed to initialize database: %w", err) } return repo, nil } func (r *SqlRepo) preCheck() error { // WireGuard Portal v1 database migration table type DatabaseMigrationInfo struct { Version string `gorm:"primaryKey"` Applied time.Time } // temporarily disable logger as the next request might fail (intentionally) r.db.Logger.LogMode(logger.Silent) defer func() { r.db.Logger.LogMode(logger.Info) }() lastVersion := DatabaseMigrationInfo{} err := r.db.Order("applied desc, version desc").FirstOrInit(&lastVersion).Error if err != nil { return nil // we probably don't have a V1 database =) } return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first", lastVersion.Version) } func (r *SqlRepo) migrate() error { slog.Debug("running migration: sys-stat", "result", r.db.AutoMigrate(&SysStat{})) slog.Debug("running migration: user", "result", r.db.AutoMigrate(&domain.User{})) slog.Debug("running migration: user webauthn credentials", "result", r.db.AutoMigrate(&domain.UserWebauthnCredential{})) slog.Debug("running migration: interface", "result", r.db.AutoMigrate(&domain.Interface{})) slog.Debug("running migration: peer", "result", 
r.db.AutoMigrate(&domain.Peer{})) slog.Debug("running migration: peer status", "result", r.db.AutoMigrate(&domain.PeerStatus{})) slog.Debug("running migration: interface status", "result", r.db.AutoMigrate(&domain.InterfaceStatus{})) slog.Debug("running migration: audit data", "result", r.db.AutoMigrate(&domain.AuditEntry{})) slog.Debug("running migration: node sync lock", "result", r.db.AutoMigrate(&NodeSyncLock{})) // Clean up deprecated columns from peer_statuses table (traffic accumulation refactor) // Previously we had PreviousSessionBytesReceived/Transmitted columns, now we use a simpler approach // These columns are safe to drop if they exist if r.db.Migrator().HasColumn("peer_statuses", "previous_session_received") { slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "previous_session_received") r.db.Migrator().DropColumn("peer_statuses", "previous_session_received") } if r.db.Migrator().HasColumn("peer_statuses", "previous_session_transmitted") { slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "previous_session_transmitted") r.db.Migrator().DropColumn("peer_statuses", "previous_session_transmitted") } // Clean up accumulated traffic columns (no longer needed - use current session traffic only) // Simplified to keep only current session bytes which are accurate from WireGuard if r.db.Migrator().HasColumn("peer_statuses", "accumulated_received") { slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "accumulated_received") r.db.Migrator().DropColumn("peer_statuses", "accumulated_received") } if r.db.Migrator().HasColumn("peer_statuses", "accumulated_transmitted") { slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "accumulated_transmitted") r.db.Migrator().DropColumn("peer_statuses", "accumulated_transmitted") } existingSysStat := SysStat{} r.db.Where("schema_version = ?", SchemaVersion).First(&existingSysStat) if existingSysStat.SchemaVersion == 0 { 
sysStat := SysStat{ MigratedAt: time.Now(), SchemaVersion: SchemaVersion, } if err := r.db.Create(&sysStat).Error; err != nil { return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", SchemaVersion, err) } slog.Debug("sys-stat entry written", "schema_version", SchemaVersion) } return nil } // region interfaces // GetInterface returns the interface with the given id. // If no interface is found, an error domain.ErrNotFound is returned. func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) { var in domain.Interface err := r.db.WithContext(ctx).Preload("Addresses").First(&in, id).Error if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrNotFound } if err != nil { return nil, err } return &in, nil } // GetInterfaceAndPeers returns the interface with the given id and all peers associated with it. // If no interface is found, an error domain.ErrNotFound is returned. func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( *domain.Interface, []domain.Peer, error, ) { in, err := r.GetInterface(ctx, id) if err != nil { return nil, nil, fmt.Errorf("failed to load interface: %w", err) } peers, err := r.GetInterfacePeers(ctx, id) if err != nil { return nil, nil, fmt.Errorf("failed to load peers: %w", err) } return in, peers, nil } // GetPeersStats returns the stats for the given peer ids. The order of the returned stats is not guaranteed. func (r *SqlRepo) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) { if len(ids) == 0 { return nil, nil } var stats []domain.PeerStatus err := r.db.WithContext(ctx).Where("identifier IN ?", ids).Find(&stats).Error if err != nil { return nil, err } return stats, nil } // GetAllInterfaces returns all interfaces. 
func (r *SqlRepo) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { var interfaces []domain.Interface err := r.db.WithContext(ctx).Preload("Addresses").Find(&interfaces).Error if err != nil { return nil, err } return interfaces, nil } // GetInterfaceStats returns the stats for the given interface id. // If no stats are found, an error domain.ErrNotFound is returned. func (r *SqlRepo) GetInterfaceStats(ctx context.Context, id domain.InterfaceIdentifier) ( *domain.InterfaceStatus, error, ) { if id == "" { return nil, nil } var stats []domain.InterfaceStatus err := r.db.WithContext(ctx).Where("identifier = ?", id).Find(&stats).Error if err != nil { return nil, err } if len(stats) == 0 { return nil, domain.ErrNotFound } stat := stats[0] return &stat, nil } // FindInterfaces returns all interfaces that match the given search string. // The search string is matched against the interface identifier and display name. func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error) { var users []domain.Interface searchValue := "%" + strings.ToLower(search) + "%" err := r.db.WithContext(ctx). Where("identifier LIKE ?", searchValue). Or("display_name LIKE ?", searchValue). Preload("Addresses"). Find(&users).Error if err != nil { return nil, err } return users, nil } // SaveInterface updates the interface with the given id. 
func (r *SqlRepo) SaveInterface( ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.Interface) (*domain.Interface, error), ) error { userInfo := domain.GetUserInfo(ctx) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { in, err := r.getOrCreateInterface(userInfo, tx, id) if err != nil { return err // return any error will roll back } in, err = updateFunc(in) if err != nil { return err } err = r.upsertInterface(userInfo, tx, in) if err != nil { return err } // return nil will commit the whole transaction return nil }) if err != nil { return err } return nil } func (r *SqlRepo) getOrCreateInterface( ui *domain.ContextUserInfo, tx *gorm.DB, id domain.InterfaceIdentifier, ) (*domain.Interface, error) { var in domain.Interface // interfaceDefaults will be applied to newly created interface records interfaceDefaults := domain.Interface{ BaseModel: domain.BaseModel{ CreatedBy: ui.UserId(), UpdatedBy: ui.UserId(), CreatedAt: time.Now(), UpdatedAt: time.Now(), }, Identifier: id, } err := tx.Preload("Addresses").Attrs(interfaceDefaults).FirstOrCreate(&in, id).Error if err != nil { return nil, err } return &in, nil } func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *domain.Interface) error { in.UpdatedBy = ui.UserId() in.UpdatedAt = time.Now() err := tx.Save(in).Error if err != nil { return err } // Only update addresses if they were explicitly set (not nil) // This prevents accidentally deleting all addresses when loading interface without preload if in.Addresses != nil { err = tx.Model(in).Association("Addresses").Replace(in.Addresses) if err != nil { return fmt.Errorf("failed to update interface addresses: %w", err) } } return nil } // DeleteInterface deletes the interface with the given id. 
func (r *SqlRepo) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { err := tx.Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error if err != nil { return err } err = tx.Delete(&domain.InterfaceStatus{InterfaceId: id}).Error if err != nil { return err } err = tx.Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error if err != nil { return err } return nil }) if err != nil { return err } return nil } // GetInterfaceIps returns a map of interface identifiers to their respective IP addresses. func (r *SqlRepo) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) { var ips []struct { domain.Cidr InterfaceId domain.InterfaceIdentifier `gorm:"column:interface_identifier"` } err := r.db.WithContext(ctx). Table("interface_addresses"). Joins("LEFT JOIN cidrs ON interface_addresses.cidr_cidr = cidrs.cidr"). Scan(&ips).Error if err != nil { return nil, err } result := make(map[domain.InterfaceIdentifier][]domain.Cidr) for _, ip := range ips { result[ip.InterfaceId] = append(result[ip.InterfaceId], ip.Cidr) } return result, nil } // endregion interfaces // region peers // GetPeer returns the peer with the given id. // If no peer is found, an error domain.ErrNotFound is returned. func (r *SqlRepo) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) { var peer domain.Peer err := r.db.WithContext(ctx).Preload("Addresses").First(&peer, id).Error if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrNotFound } if err != nil { return nil, err } return &peer, nil } // GetInterfacePeers returns all peers associated with the given interface id. 
func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) { var peers []domain.Peer err := r.db.WithContext(ctx).Preload("Addresses").Where("interface_identifier = ?", id).Find(&peers).Error if err != nil { return nil, err } return peers, nil } // FindInterfacePeers returns all peers associated with the given interface id that match the given search string. // The search string is matched against the peer identifier, display name and IP address. func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ( []domain.Peer, error, ) { var peers []domain.Peer searchValue := "%" + strings.ToLower(search) + "%" err := r.db.WithContext(ctx).Where("interface_identifier = ?", id). Where("identifier LIKE ?", searchValue). Or("display_name LIKE ?", searchValue). Or("iface_address_str_v LIKE ?", searchValue). Find(&peers).Error if err != nil { return nil, err } return peers, nil } // GetUserPeers returns all peers associated with the given user id. func (r *SqlRepo) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { var peers []domain.Peer err := r.db.WithContext(ctx).Preload("Addresses").Where("user_identifier = ?", id).Find(&peers).Error if err != nil { return nil, err } return peers, nil } // FindUserPeers returns all peers associated with the given user id that match the given search string. // The search string is matched against the peer identifier, display name and IP address. func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error) { var peers []domain.Peer searchValue := "%" + strings.ToLower(search) + "%" err := r.db.WithContext(ctx).Where("user_identifier = ?", id). Where("identifier LIKE ?", searchValue). Or("display_name LIKE ?", searchValue). Or("iface_address_str_v LIKE ?", searchValue). 
Find(&peers).Error if err != nil { return nil, err } return peers, nil } // SavePeer updates the peer with the given id. // If no existing peer is found, a new peer is created. // IMPORTANT: Also creates peer_status record to avoid deadlock during concurrent updates func (r *SqlRepo) SavePeer( ctx context.Context, id domain.PeerIdentifier, updateFunc func(in *domain.Peer) (*domain.Peer, error), ) error { userInfo := domain.GetUserInfo(ctx) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { peer, err := r.getOrCreatePeer(userInfo, tx, id) if err != nil { return err // return any error will roll back } peer, err = updateFunc(peer) if err != nil { return err } err = r.upsertPeer(userInfo, tx, peer) if err != nil { return err } // DEADLOCK FIX: Ensure peer_status record exists WITHOUT FOR UPDATE lock // This prevents lock contention when stats collection immediately tries to update it // Use explicit WHERE clause with correct column name (identifier, not peer_id) peerStatus := domain.PeerStatus{ PeerId: id, UpdatedAt: time.Now(), } // Explicit WHERE using correct column name for FirstOrCreate err = tx.Where("identifier = ?", id).FirstOrCreate(&peerStatus).Error if err != nil { // Log but don't fail - peer_status will be created on first stats update anyway slog.Debug("peer status record already exists or creation deferred", "peer", id) } // return nil will commit the whole transaction return nil }) if err != nil { return err } return nil } func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) ( *domain.Peer, error, ) { var peer domain.Peer // interfaceDefaults will be applied to newly created interface records interfaceDefaults := domain.Peer{ BaseModel: domain.BaseModel{ CreatedBy: ui.UserId(), UpdatedBy: ui.UserId(), CreatedAt: time.Now(), UpdatedAt: time.Now(), }, Identifier: id, } err := tx.Attrs(interfaceDefaults).FirstOrCreate(&peer, id).Error if err != nil { return nil, err } return &peer, nil } func 
(r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *domain.Peer) error { peer.UpdatedBy = ui.UserId() peer.UpdatedAt = time.Now() err := tx.Save(peer).Error if err != nil { return err } err = tx.Model(peer).Association("Addresses").Replace(peer.Interface.Addresses) if err != nil { return fmt.Errorf("failed to update peer addresses: %w", err) } return nil } // DeletePeer deletes the peer with the given id. // This also deletes the peer_addresses associations (many-to-many relationships with Cidr) // NOTE: We do NOT delete peer_status here. // The peer_status will be cleaned up by CleanOrphanedStatuses on all cluster nodes. // This ensures that other nodes can detect orphaned statuses and clean up their metrics. func (r *SqlRepo) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // First: delete peer_addresses many2many associations to avoid foreign key issues // This is critical because we use raw SQL and explicit deletion like in DeletePeersByIDs if err := tx.Table("peer_addresses").Where("peer_identifier = ?", string(id)).Delete(nil).Error; err != nil { return err } // Second: delete the peer itself err := tx.Where("identifier = ?", string(id)).Delete(&domain.Peer{}).Error if err != nil { return err } return nil }) if err != nil { return err } return nil } // GetPeerIps returns a map of peer identifiers to their respective IP addresses. func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]domain.Cidr, error) { var ips []struct { domain.Cidr PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"` } err := r.db.WithContext(ctx). Table("peer_addresses"). Joins("LEFT JOIN cidrs ON peer_addresses.cidr_cidr = cidrs.cidr"). 
Scan(&ips).Error if err != nil { return nil, err } result := make(map[domain.PeerIdentifier][]domain.Cidr) for _, ip := range ips { result[ip.PeerId] = append(result[ip.PeerId], ip.Cidr) } return result, nil } // GetUsedIpsPerSubnet returns a map of subnets to their respective used IP addresses. func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) ( map[domain.Cidr][]domain.Cidr, error, ) { var peerIps []struct { domain.Cidr PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"` } err := r.db.WithContext(ctx). Table("peer_addresses"). Joins("LEFT JOIN cidrs ON peer_addresses.cidr_cidr = cidrs.cidr"). Scan(&peerIps).Error if err != nil { return nil, fmt.Errorf("failed to fetch peer IP's: %w", err) } var interfaceIps []struct { domain.Cidr InterfaceId domain.InterfaceIdentifier `gorm:"column:interface_identifier"` } err = r.db.WithContext(ctx). Table("interface_addresses"). Joins("LEFT JOIN cidrs ON interface_addresses.cidr_cidr = cidrs.cidr"). Scan(&interfaceIps).Error if err != nil { return nil, fmt.Errorf("failed to fetch interface IP's: %w", err) } result := make(map[domain.Cidr][]domain.Cidr, len(subnets)) for _, ip := range interfaceIps { var subnet domain.Cidr // default empty subnet (if no subnet matches, we will add the IP to the empty subnet group) for _, s := range subnets { if s.Contains(ip.Cidr) { subnet = s break } } result[subnet] = append(result[subnet], ip.Cidr) } for _, ip := range peerIps { var subnet domain.Cidr // default empty subnet (if no subnet matches, we will add the IP to the empty subnet group) for _, s := range subnets { if s.Contains(ip.Cidr) { subnet = s break } } result[subnet] = append(result[subnet], ip.Cidr) } return result, nil } // endregion peers // region users // GetUser returns the user with the given id. // If no user is found, an error domain.ErrNotFound is returned. 
func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { var user domain.User err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").First(&user, id).Error if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrNotFound } if err != nil { return nil, err } return &user, nil } // GetUserByEmail returns the user with the given email. // If no user is found, an error domain.ErrNotFound is returned. // If multiple users are found, an error domain.ErrNotUnique is returned. func (r *SqlRepo) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) { var users []domain.User err := r.db.WithContext(ctx).Where("email = ?", email).Preload("WebAuthnCredentialList").Find(&users).Error if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrNotFound } if err != nil { return nil, err } if len(users) == 0 { return nil, domain.ErrNotFound } if len(users) > 1 { return nil, fmt.Errorf("found multiple users with email %s: %w", email, domain.ErrNotUnique) } user := users[0] return &user, nil } // GetUserByWebAuthnCredential returns the user with the given webauthn credential id. func (r *SqlRepo) GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error) { var credential domain.UserWebauthnCredential err := r.db.WithContext(ctx).Where("credential_identifier = ?", credentialIdBase64).First(&credential).Error if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return nil, domain.ErrNotFound } if err != nil { return nil, err } return r.GetUser(ctx, domain.UserIdentifier(credential.UserIdentifier)) } // GetAllUsers returns all users. func (r *SqlRepo) GetAllUsers(ctx context.Context) ([]domain.User, error) { var users []domain.User err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").Find(&users).Error if err != nil { return nil, err } return users, nil } // FindUsers returns all users that match the given search string. 
// The search string is matched against the user identifier, firstname, lastname and email. func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User, error) { var users []domain.User searchValue := "%" + strings.ToLower(search) + "%" err := r.db.WithContext(ctx). Where("identifier LIKE ?", searchValue). Or("firstname LIKE ?", searchValue). Or("lastname LIKE ?", searchValue). Or("email LIKE ?", searchValue). Preload("WebAuthnCredentialList"). Find(&users).Error if err != nil { return nil, err } return users, nil } // SaveUser updates the user with the given id. // If no user is found, a new user is created. func (r *SqlRepo) SaveUser( ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error), ) error { userInfo := domain.GetUserInfo(ctx) err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { user, err := r.getOrCreateUser(userInfo, tx, id) if err != nil { return err // return any error will roll back } user, err = updateFunc(user) if err != nil { return err } err = r.upsertUser(userInfo, tx, user) if err != nil { return err } // return nil will commit the whole transaction return nil }) if err != nil { return err } return nil } // DeleteUser deletes the user with the given id. 
func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) error { err := r.db.WithContext(ctx).Unscoped().Select(clause.Associations).Delete(&domain.User{Identifier: id}).Error if err != nil { return err } return nil } func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) ( *domain.User, error, ) { var user domain.User // userDefaults will be applied to newly created user records userDefaults := domain.User{ BaseModel: domain.BaseModel{ CreatedBy: ui.UserId(), UpdatedBy: ui.UserId(), CreatedAt: time.Now(), UpdatedAt: time.Now(), }, Identifier: id, Source: domain.UserSourceDatabase, IsAdmin: false, } err := tx.Attrs(userDefaults).FirstOrCreate(&user, id).Error if err != nil { return nil, err } return &user, nil } func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *domain.User) error { user.UpdatedBy = ui.UserId() user.UpdatedAt = time.Now() err := tx.Save(user).Error if err != nil { return err } err = tx.Session(&gorm.Session{FullSaveAssociations: true}).Unscoped().Model(user).Association("WebAuthnCredentialList").Unscoped().Replace(user.WebAuthnCredentialList) if err != nil { return fmt.Errorf("failed to update users webauthn credentials: %w", err) } return nil } // endregion users // region statistics // UpdateInterfaceStatus updates the interface status with the given id. // If no interface status is found, a new one is created. 
func (r *SqlRepo) UpdateInterfaceStatus( ctx context.Context, id domain.InterfaceIdentifier, updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error), ) error { err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { in, err := r.getOrCreateInterfaceStatus(tx, id) if err != nil { return err // return any error will roll back } in, err = updateFunc(in) if err != nil { return err } err = r.upsertInterfaceStatus(tx, in) if err != nil { return err } // return nil will commit the whole transaction return nil }) if err != nil { return err } return nil } func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) ( *domain.InterfaceStatus, error, ) { var in domain.InterfaceStatus // defaults will be applied to newly created record defaults := domain.InterfaceStatus{ InterfaceId: id, UpdatedAt: time.Now(), } err := tx.Attrs(defaults).FirstOrCreate(&in, id).Error if err != nil { return nil, err } return &in, nil } func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus) error { err := tx.Save(in).Error if err != nil { return err } return nil } // UpdatePeerStatus updates the peer status with the given id. // If no peer status is found, a new one is created. // OWNERSHIP CHECK: If peer has an OwnerNodeId set, skip update to prevent conflicts with owner node. // Includes automatic retry logic with exponential backoff for deadlock/conflict recovery. // Uses row-level FOR UPDATE locking to prevent concurrent modification conflicts. 
func (r *SqlRepo) UpdatePeerStatus(
	ctx context.Context,
	id domain.PeerIdentifier,
	updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
) error {
	// Retry logic with exponential backoff for deadlocks and conflicts
	// Increased to 10 retries (50ms, 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, 6.4s, 12.8s, 25.6s)
	// This handles high-concurrency scenarios with 24 cluster nodes and 100+ peers
	// Total max wait time: ~50 seconds for all retries
	// NOTE(review): with maxRetries = 10 and a wait only when attempt > 0, waits happen before
	// attempts 1..9. If the lost code implements the formula in the comment below
	// (50ms * 2^(attempt-1)), the waits are 50ms..12.8s, about 25.6s in total plus jitter,
	// not ~50s. Re-check these numbers once the code is restored.
	const maxRetries = 10
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff with jitter to reduce synchronized retries
			// Base: 50ms * 2^(attempt-1) + random 0-50ms jitter
			// NOTE(review): the source is corrupted from the next line on. The text between a
			// `<` in the shift expression and some later `>` is missing (possibly stripped as
			// markup). As a result, the rest of UpdatePeerStatus, at least one further retrying
			// function, and the head of the batch upsert whose error message names
			// BatchUpdatePeerStatuses are fused together. This region does not compile as
			// shown. Restore it from version control rather than reconstructing it by hand.
			baseWait := time.Duration(50*(1< 0 {
			// Exponential backoff with jitter to reduce synchronized retries
			// Base: 50ms * 2^(attempt-1) + random 0-50ms jitter
			baseWait := time.Duration(50*(1< 0 {
			// Exponential backoff: 50ms, 100ms, 200ms, 400ms, 800ms
			waitTime := time.Duration(50*(1< 0 {
				// Upsert all collected statuses: on a key conflict every column is overwritten.
				err := tx.Clauses(clause.OnConflict{
					UpdateAll: true,
				}).CreateInBatches(allStatuses, 50).Error // batch insert in groups of 50
				if err != nil {
					return err
				}
			}
			return nil
		})
		if err == nil {
			return nil // success
		}
		lastErr = err
		// Check if this is a retryable error
		// NOTE(review): matching on message text is driver-specific. "Record has changed" looks
		// like MySQL/MariaDB wording, and PostgreSQL reports "deadlock detected" in lowercase,
		// which this case-sensitive check would not treat as retryable. Confirm the intended
		// backends.
		errMsg := err.Error()
		if !strings.Contains(errMsg, "Deadlock") && !strings.Contains(errMsg, "Record has changed") {
			return err // not a retryable error
		}
		// Continue to retry
	}
	return fmt.Errorf("BatchUpdatePeerStatuses failed after %d retries: %w", maxRetries, lastErr)
}

// getOrCreatePeerStatus returns the status row of peer id, read inside tx with a
// clause.Locking{Strength: "UPDATE"} (SELECT ... FOR UPDATE) clause. If the row does not
// exist, it is first created from defaults without a lock and then re-read with the lock.
// NOTE(review): whether the row lock is honoured depends on the SQL dialect in use; confirm
// for the SQLite backend.
func (r *SqlRepo) getOrCreatePeerStatus(tx *gorm.DB, id domain.PeerIdentifier) (*domain.PeerStatus, error) {
	var in domain.PeerStatus
	// defaults will be applied to newly created record
	defaults := domain.PeerStatus{
		PeerId:    id,
		UpdatedAt: time.Now(),
	}
	// DEADLOCK FIX: Two-phase approach to reduce lock contention
	// 1. First try to get existing record with FOR UPDATE lock
	// (This avoids gap locks on non-existent rows)
	err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("identifier = ?", id).First(&in).Error
	if err == nil {
		// Record found and locked - proceed with update
		return &in, nil
	}
	// NOTE(review): a plain != comparison misses a wrapped ErrRecordNotFound;
	// errors.Is(err, gorm.ErrRecordNotFound) would be more robust.
	if err != gorm.ErrRecordNotFound {
		// Unexpected error
		return nil, err
	}
	// Record doesn't exist - create it WITHOUT FOR UPDATE lock
	// This is safe because SavePeer pre-creates peer_status, so this is rare
	// Explicit WHERE clause using correct column name
	err = tx.Where("identifier = ?", id).Attrs(defaults).FirstOrCreate(&in).Error
	if err != nil {
		return nil, err
	}
	// Now that record exists, acquire FOR UPDATE lock for modifications
	err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("identifier = ?", id).First(&in).Error
	if err != nil {
		return nil, fmt.Errorf("failed to lock peer status after creation: %w", err)
	}
	return &in, nil
}

// getOrCreatePeerStatusForRead returns the status row of peer id without a row lock.
// Despite the name, it still inserts a new row (from defaults) via FirstOrCreate when none
// exists, so it is not side-effect free.
func (r *SqlRepo) getOrCreatePeerStatusForRead(tx *gorm.DB, id domain.PeerIdentifier) (*domain.PeerStatus, error) {
	var in domain.PeerStatus
	// defaults will be applied to newly created record
	defaults := domain.PeerStatus{
		PeerId:    id,
		UpdatedAt: time.Now(),
	}
	// For read-only access without modifications (used in conditionals)
	// No FOR UPDATE lock - allows concurrent reads
	err := tx.Attrs(defaults).FirstOrCreate(&in, id).Error
	if err != nil {
		return nil, err
	}
	return &in, nil
}

// upsertPeerStatus persists in through tx.Save.
func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
	err := tx.Save(in).Error
	if err != nil {
		return err
	}
	return nil
}

// DeletePeerStatus deletes the peer status with the given id.
func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
	err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
	if err != nil {
		return err
	}
	return nil
}

// GetAllPeerStatuses returns all peer statuses from the database.
func (r *SqlRepo) GetAllPeerStatuses(ctx context.Context) ([]domain.PeerStatus, error) { var statuses []domain.PeerStatus err := r.db.WithContext(ctx).Find(&statuses).Error if err != nil { return nil, err } return statuses, nil } // endregion statistics // region audit // SaveAuditEntry saves the given audit entry. func (r *SqlRepo) SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error { err := r.db.WithContext(ctx).Save(entry).Error if err != nil { return err } return nil } // GetAllAuditEntries retrieves all audit entries from the database. // The entries are ordered by timestamp, with the newest entries first. func (r *SqlRepo) GetAllAuditEntries(ctx context.Context) ([]domain.AuditEntry, error) { var entries []domain.AuditEntry err := r.db.WithContext(ctx).Order("created_at desc").Find(&entries).Error if err != nil { return nil, err } return entries, nil } // endregion audit func (r *SqlRepo) GetAllPeers(ctx context.Context) ([]domain.Peer, error) { var peers []domain.Peer // OPTIMIZED: Select only peer identifiers and basic fields to avoid N+1 queries // from Preload("Addresses") and Preload("Interface") which would create 600+ queries // This method is used mainly for cleanup validation, not for full peer data display err := r.db.WithContext(ctx). Select("id, identifier, display_name, interface_identifier"). Find(&peers).Error if err != nil { return nil, err } return peers, nil } // region node sync lock const ( syncLockKey = "peer:sync" syncLockDuration = 5 * time.Minute // Auto-release stuck locks after 5 minutes syncLockTimeout = 2 * time.Minute // Total wait time for acquiring lock ) // AcquireSyncLock attempts to acquire the global peer sync lock. // Returns nodeID that holds the lock, or empty string if we acquired it. // Uses exponential backoff to retry up to syncLockTimeout. 
func (r *SqlRepo) AcquireSyncLock(ctx context.Context, nodeID string) (acquiredBy string, err error) { // Use generic lock mechanism return r.AcquireLock(ctx, syncLockKey, nodeID, syncLockDuration) } // ReleaseSyncLock releases the global peer sync lock. func (r *SqlRepo) ReleaseSyncLock(ctx context.Context, nodeID string) error { return r.ReleaseLock(ctx, syncLockKey, nodeID) } // SyncAllPeersFromDBWithLock wraps sync to ensure only one node syncs at a time. // This prevents database contention and 504 Gateway Timeout cascades. func (r *SqlRepo) SyncAllPeersFromDBWithLock(ctx context.Context, nodeID string) (int, error) { // Try to acquire lock with timeout lockCtx, cancel := context.WithTimeout(ctx, syncLockTimeout) defer cancel() heldBy, err := r.AcquireSyncLock(lockCtx, nodeID) if err != nil { slog.Warn("[SYNC_LOCK] cannot acquire sync lock, another node is syncing", "held_by", heldBy, "self", nodeID, "error", err) // Return 0,0 to indicate no sync performed (not an error condition) return 0, nil } // Ensure lock is released after we're done (even if sync fails) defer func() { if relErr := r.ReleaseSyncLock(context.Background(), nodeID); relErr != nil { slog.Error("failed to release sync lock", "error", relErr) } // Clean up expired locks in background to prevent table bloat go func() { goCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if cleanErr := r.CleanupExpiredLocks(goCtx); cleanErr != nil { slog.Debug("cleanup expired locks failed", "error", cleanErr) } }() }() // Now perform the actual sync with the lock held count, err := r.SyncAllPeersFromDB(ctx) if err != nil { slog.Error("[SYNC_LOCK] sync failed while holding lock", "node_id", nodeID, "error", err) return count, err } slog.Info("[SYNC_LOCK] sync completed", "node_id", nodeID, "synced_count", count) return count, nil } // endregion node sync lock // SyncAllPeersFromDB synchronizes all peers from the database. 
func (r *SqlRepo) SyncAllPeersFromDB(ctx context.Context) (int, error) { slog.Debug("SyncAllPeersFromDB called, but no operation performed") return 0, nil } // region expired peers cleanup const ( expireCleanupLockKey = "expire:cleanup" expireCleanupLockDuration = 10 * time.Minute // Lock to ensure only one node performs cleanup expireCleanupTimeout = 30 * time.Second // Timeout for acquiring lock ) // GetExpiredPeers finds all peers with expiredAt in the past func (r *SqlRepo) GetExpiredPeers(ctx context.Context) ([]domain.Peer, error) { var peers []domain.Peer now := time.Now() err := r.db.WithContext(ctx). Where("expires_at IS NOT NULL AND expires_at < ?", now). Select("identifier, interface_identifier"). Find(&peers).Error if err != nil { return nil, err } return peers, nil } // DeletePeersByIDs deletes peers by their Identifier (public key) and their associated addresses // Used when cleaning up expired peers // IMPORTANT: Deletes peer_addresses associations first to avoid foreign key constraint violations // Also deletes peer statuses to avoid orphaned status records func (r *SqlRepo) DeletePeersByIDs(ctx context.Context, peerIDs []string) (int64, error) { if len(peerIDs) == 0 { return 0, nil } // Start transaction to ensure atomic deletion (associations first, then peers, then statuses) tx := r.db.WithContext(ctx).Begin() if tx.Error != nil { return 0, tx.Error } // First: delete peer_addresses many2many associations via raw SQL if err := tx.Table("peer_addresses").Where("peer_identifier IN ?", peerIDs).Delete(nil).Error; err != nil { tx.Rollback() return 0, err } // Second: delete peer statuses to avoid orphaned status records if err := tx.Where("identifier IN ?", peerIDs).Delete(&domain.PeerStatus{}).Error; err != nil { tx.Rollback() return 0, err } // Third: delete peers result := tx.Where("identifier IN ?", peerIDs).Delete(&domain.Peer{}) if result.Error != nil { tx.Rollback() return 0, result.Error } // Commit transaction if err := tx.Commit().Error; 
err != nil { return 0, err } return result.RowsAffected, nil } // FindAndDeleteExpiredPeersWithLock ensures only MASTER node deletes expired peers // Other nodes skip cleanup to avoid conflicts across multiple nodes // If this node is not configured as master, cleanup is skipped func (r *SqlRepo) FindAndDeleteExpiredPeersWithLock(ctx context.Context, nodeID string) (expiredPeerIDs []string, err error) { // Only MASTER node can delete expired peers if !r.cfg.Core.Master { slog.Debug("[EXPIRE_CLEANUP] this node is not master, skipping cleanup", "node_id", nodeID) return nil, nil } // Find expired peers expiredPeers, err := r.GetExpiredPeers(ctx) if err != nil { slog.Error("[EXPIRE_CLEANUP] failed to get expired peers", "error", err) return nil, err } if len(expiredPeers) == 0 { slog.Debug("[EXPIRE_CLEANUP] no expired peers found") return nil, nil } // Delete them from DB peerIDs := make([]string, len(expiredPeers)) for i, p := range expiredPeers { peerIDs[i] = string(p.Identifier) } deletedCount, err := r.DeletePeersByIDs(ctx, peerIDs) if err != nil { slog.Error("[EXPIRE_CLEANUP] failed to delete expired peers", "error", err, "count", len(peerIDs)) return nil, err } slog.Info("[EXPIRE_CLEANUP] deleted expired peers", "node_id", nodeID, "count", deletedCount, "peer_ids_count", len(peerIDs)) return peerIDs, nil } // endregion expired peers cleanup // region generic distributed locking (for reuse) // AcquireLock is a generic lock mechanism for various operations. // IMPORTANT: Uses simple INSERT without ON DUPLICATE KEY UPDATE to avoid deadlocks. // If lock already exists and is not expired, returns who holds it (no retry loop). func (r *SqlRepo) AcquireLock(ctx context.Context, lockKey string, nodeID string, duration time.Duration) (acquiredBy string, err error) { now := time.Now() // Clean up expired locks (but don't retry on failure) if delErr := r.db.WithContext(ctx). Where("lock_key = ? AND expires_at < ?", lockKey, now). 
Delete(&NodeSyncLock{}).Error; delErr != nil { slog.Debug("failed to clean expired locks", "lock_key", lockKey, "error", delErr) } // Try simple INSERT (no ON DUPLICATE KEY UPDATE to avoid deadlock) result := r.db.WithContext(ctx). Create(&NodeSyncLock{ LockKey: lockKey, NodeID: nodeID, LockedAt: now, ExpiresAt: now.Add(duration), }) if result.Error == nil { slog.Debug("[LOCK] acquired", "lock_key", lockKey, "node_id", nodeID) return "", nil } // INSERT failed - check who holds the lock (no retry, just check) var lock NodeSyncLock if err := r.db.WithContext(ctx). Where("lock_key = ?", lockKey). First(&lock).Error; err == nil { if lock.ExpiresAt.Before(now) { // Lock expired, clean it up (best effort) r.db.WithContext(ctx).Delete(&lock) return "", fmt.Errorf("lock expired, please retry") } // Lock is held by another node acquiredBy = lock.NodeID return acquiredBy, fmt.Errorf("lock held by %s", acquiredBy) } // Couldn't determine lock holder return "", fmt.Errorf("failed to acquire lock %s", lockKey) } // ReleaseLock releases a distributed lock func (r *SqlRepo) ReleaseLock(ctx context.Context, lockKey string, nodeID string) error { // Only delete lock if it belongs to us AND it hasn't expired yet // This prevents deadlocks from competing DELETE operations on expired locks result := r.db.WithContext(ctx). Where("lock_key = ? AND node_id = ? AND expires_at > NOW()", lockKey, nodeID). 
Delete(&NodeSyncLock{}) if result.Error != nil { // Log but don't fail - lock might already be released or expired slog.Debug("failed to release lock", "lockKey", lockKey, "nodeID", nodeID, "error", result.Error) return nil } if result.RowsAffected == 0 { // Lock was already released or expired - not an error slog.Debug("lock already released or expired", "lockKey", lockKey, "nodeID", nodeID) return nil } return nil } // CleanupExpiredLocks removes all expired locks from the database // This prevents deadlocks from piling up expired lock rows func (r *SqlRepo) CleanupExpiredLocks(ctx context.Context) error { result := r.db.WithContext(ctx). Where("expires_at < ?", time.Now()). Delete(&NodeSyncLock{}) if result.Error != nil { slog.Warn("failed to cleanup expired locks", "error", result.Error) return result.Error } if result.RowsAffected > 0 { slog.Debug("cleaned up expired locks", "rows_deleted", result.RowsAffected) } return nil } // endregion generic distributed locking