mirror of https://github.com/h44z/wg-portal.git
1862 lines
55 KiB
Go
1862 lines
55 KiB
Go
package adapters
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"math/rand"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/glebarez/sqlite"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlserver"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/logger"
|
|
"gorm.io/gorm/utils"
|
|
|
|
"github.com/fedor-git/wg-portal-2/internal/config"
|
|
"github.com/fedor-git/wg-portal-2/internal/domain"
|
|
)
|
|
|
|
// SchemaVersion is the version number of the database schema expected by this build.
// Increase it whenever a schema change requires a manual migration step.
var SchemaVersion uint64 = 1
|
|
|
|
// SysStat records which database schema version has been applied and when.
type SysStat struct {
	// MigratedAt is the time at which the schema version entry was written.
	MigratedAt time.Time `gorm:"column:migrated_at"`
	// SchemaVersion is the applied schema version and the primary key of the table.
	//
	// GORM separates tag options with ';'. The previous "primaryKey,column:..." form was
	// parsed as one unknown option, so the primary key was never declared. The value is
	// always assigned explicitly, therefore auto-increment is disabled.
	SchemaVersion uint64 `gorm:"primaryKey;autoIncrement:false;column:schema_version"`
}
|
|
|
|
// NodeSyncLock is the database row used as a lock for peer synchronization between nodes.
// According to the original author's note, only one node may hold the lock at a time, which
// avoids concurrent syncs that cause database contention and 504 Gateway Timeout errors.
type NodeSyncLock struct {
	// LockKey identifies the lock and is the primary key of the table.
	LockKey string `gorm:"primaryKey;column:lock_key"`
	// NodeID is stored in the indexed node_id column.
	NodeID string `gorm:"column:node_id;index"`
	// LockedAt is stored in the locked_at column.
	LockedAt time.Time `gorm:"column:locked_at"`
	// ExpiresAt is stored in the indexed expires_at column. The original note states that
	// stuck locks are auto-released after 5 minutes; that duration is not set in this type.
	ExpiresAt time.Time `gorm:"column:expires_at;index"`
}

// TableName returns the name of the database table that stores NodeSyncLock rows.
func (NodeSyncLock) TableName() string {
	const table = "node_sync_locks"
	return table
}
|
|
|
|
// GormLogger adapts Gorm's logging to the slog package.
type GormLogger struct {
	// SlowThreshold is the duration above which a statement is logged as a warning.
	// A value of zero disables slow statement warnings.
	SlowThreshold time.Duration
	// SourceField is the attribute key under which the calling source location is logged.
	// An empty value omits the source location.
	SourceField string
	// IgnoreErrRecordNotFound suppresses error logs for gorm.ErrRecordNotFound.
	IgnoreErrRecordNotFound bool
	// Debug enables debug level logging of every traced statement.
	Debug bool
	// Silent disables all output of this logger.
	Silent bool

	// prefix is prepended to every log message.
	prefix string
}
|
|
|
|
func NewLogger(slowThreshold time.Duration, debug bool) *GormLogger {
|
|
return &GormLogger{
|
|
SlowThreshold: slowThreshold,
|
|
Debug: debug,
|
|
IgnoreErrRecordNotFound: true,
|
|
Silent: false,
|
|
SourceField: "src",
|
|
prefix: "GORM-SQL: ",
|
|
}
|
|
}
|
|
|
|
func (l *GormLogger) LogMode(level logger.LogLevel) logger.Interface {
|
|
if level == logger.Silent {
|
|
l.Silent = true
|
|
} else {
|
|
l.Silent = false
|
|
}
|
|
return l
|
|
}
|
|
|
|
func (l *GormLogger) Info(ctx context.Context, s string, args ...any) {
|
|
if l.Silent {
|
|
return
|
|
}
|
|
slog.InfoContext(ctx, l.prefix+s, args...)
|
|
}
|
|
|
|
func (l *GormLogger) Warn(ctx context.Context, s string, args ...any) {
|
|
if l.Silent {
|
|
return
|
|
}
|
|
slog.WarnContext(ctx, l.prefix+s, args...)
|
|
}
|
|
|
|
func (l *GormLogger) Error(ctx context.Context, s string, args ...any) {
|
|
if l.Silent {
|
|
return
|
|
}
|
|
slog.ErrorContext(ctx, l.prefix+s, args...)
|
|
}
|
|
|
|
func (l *GormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
|
if l.Silent {
|
|
return
|
|
}
|
|
|
|
elapsed := time.Since(begin)
|
|
sql, rows := fc()
|
|
|
|
attrs := []any{
|
|
"rows", rows,
|
|
"duration", elapsed,
|
|
}
|
|
|
|
if l.SourceField != "" {
|
|
attrs = append(attrs, l.SourceField, utils.FileWithLineNum())
|
|
}
|
|
|
|
if err != nil && !(errors.Is(err, gorm.ErrRecordNotFound) && l.IgnoreErrRecordNotFound) {
|
|
attrs = append(attrs, "error", err)
|
|
slog.ErrorContext(ctx, l.prefix+sql, attrs...)
|
|
return
|
|
}
|
|
|
|
if l.SlowThreshold != 0 && elapsed > l.SlowThreshold {
|
|
slog.WarnContext(ctx, l.prefix+sql, attrs...)
|
|
return
|
|
}
|
|
|
|
if l.Debug {
|
|
slog.DebugContext(ctx, l.prefix+sql, attrs...)
|
|
}
|
|
}
|
|
|
|
// applyConnectionPoolConfig applies the connection pool configuration from config to the database.
|
|
func applyConnectionPoolConfig(db *gorm.DB, cfg config.DatabaseConfig) error {
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get underlying sql.DB: %w", err)
|
|
}
|
|
|
|
// Use configured values, or defaults if not specified
|
|
maxOpen := cfg.MaxOpenConnections
|
|
if maxOpen <= 0 {
|
|
maxOpen = 50
|
|
}
|
|
|
|
maxIdle := cfg.MaxIdleConnections
|
|
if maxIdle <= 0 {
|
|
maxIdle = 10
|
|
}
|
|
|
|
maxLifetime := cfg.ConnectionMaxLifetime
|
|
if maxLifetime <= 0 {
|
|
maxLifetime = 3 * time.Minute
|
|
}
|
|
|
|
// For multi-node cluster with 24 nodes, ensure adequate pool size
|
|
// Each node may need 2-3 concurrent connections for:
|
|
// - Interface/peer sync operations
|
|
// - Message bus event processing
|
|
// - Regular API requests
|
|
// Recommended: maxOpen = (nodes * 2-3) + buffer, e.g., 24*2.5 = 60
|
|
if maxOpen < 60 {
|
|
slog.Warn("Connection pool size may be too small for multi-node cluster",
|
|
"configured", maxOpen, "recommended_minimum", 60)
|
|
}
|
|
if maxIdle < 10 {
|
|
maxIdle = 10 // Ensure minimum idle connections for cluster
|
|
}
|
|
|
|
sqlDB.SetMaxOpenConns(maxOpen)
|
|
sqlDB.SetMaxIdleConns(maxIdle)
|
|
sqlDB.SetConnMaxLifetime(maxLifetime)
|
|
// Set connection max idle time to recycle stale connections faster
|
|
// Prevents "too many connections" by recycling idle connections after 5 minutes
|
|
sqlDB.SetConnMaxIdleTime(5 * time.Minute)
|
|
|
|
slog.Info("Database connection pool configured",
|
|
"max_open_connections", maxOpen,
|
|
"max_idle_connections", maxIdle,
|
|
"connection_max_lifetime", maxLifetime.String(),
|
|
"connection_max_idle_time", "5m")
|
|
|
|
return nil
|
|
}
|
|
|
|
// NewDatabase creates a new database connection and returns a Gorm database instance.
|
|
func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
|
var gormDb *gorm.DB
|
|
var err error
|
|
|
|
// Retry logic for initial database connection
|
|
// With 24 nodes starting simultaneously, initial connection can fail with "Too many connections"
|
|
// Increased to 50 retries to handle all 24 nodes with exponential backoff and jitter
|
|
// This gives us up to ~2 minutes of retry time with proper staggering
|
|
const maxRetries = 50
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
if attempt > 0 {
|
|
// Exponential backoff with random jitter to prevent thundering herd
|
|
// With 24 nodes starting simultaneously, stagger retries much more aggressively
|
|
// Exponential: 50ms, 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, 6.4s...
|
|
// Jitter: ±100% randomness so not all nodes retry at same time
|
|
baseWaitTime := time.Duration(50*(1<<uint(attempt-1))) * time.Millisecond
|
|
// Cap base wait at 10 seconds to avoid excessive delays
|
|
if baseWaitTime > 10*time.Second {
|
|
baseWaitTime = 10 * time.Second
|
|
}
|
|
// Full randomness jitter to spread retries across time
|
|
jitterAmount := time.Duration(rand.Intn(int(baseWaitTime))) // ±100% jitter
|
|
waitTime := baseWaitTime + jitterAmount
|
|
|
|
slog.Warn("database connection attempt failed, retrying with exponential backoff",
|
|
"attempt", attempt+1,
|
|
"max_retries", maxRetries,
|
|
"base_wait", baseWaitTime.String(),
|
|
"actual_wait", waitTime.String(),
|
|
"error", lastErr)
|
|
time.Sleep(waitTime)
|
|
}
|
|
|
|
switch cfg.Type {
|
|
case config.DatabaseMySQL:
|
|
gormDb, err = gorm.Open(mysql.Open(cfg.DSN), &gorm.Config{
|
|
Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
|
|
})
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("failed to open MySQL database: %w", err)
|
|
if strings.Contains(err.Error(), "Too many connections") {
|
|
continue // Retry on connection pool exhaustion
|
|
}
|
|
return nil, lastErr
|
|
}
|
|
|
|
// Apply connection pool configuration
|
|
if err := applyConnectionPoolConfig(gormDb, cfg); err != nil {
|
|
lastErr = fmt.Errorf("failed to configure MySQL connection pool: %w", err)
|
|
continue
|
|
}
|
|
|
|
// Skip Ping() during initial connection to avoid connection pool exhaustion
|
|
// during multi-node startup. gorm.Open() already validates the connection.
|
|
// Ping would open another connection which might fail with "Too many connections"
|
|
slog.Info("database connected successfully", "type", "MySQL", "attempt", attempt+1)
|
|
return gormDb, nil
|
|
|
|
case config.DatabaseMsSQL:
|
|
gormDb, err = gorm.Open(sqlserver.Open(cfg.DSN), &gorm.Config{
|
|
Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
|
|
})
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("failed to open sqlserver database: %w", err)
|
|
if strings.Contains(err.Error(), "Too many connections") {
|
|
continue
|
|
}
|
|
return nil, lastErr
|
|
}
|
|
|
|
// Apply connection pool configuration
|
|
if err := applyConnectionPoolConfig(gormDb, cfg); err != nil {
|
|
lastErr = fmt.Errorf("failed to configure MSSQL connection pool: %w", err)
|
|
continue
|
|
}
|
|
|
|
slog.Info("database connected successfully", "type", "MSSQL", "attempt", attempt+1)
|
|
return gormDb, nil
|
|
|
|
case config.DatabasePostgres:
|
|
gormDb, err = gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{
|
|
Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
|
|
})
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("failed to open Postgres database: %w", err)
|
|
if strings.Contains(err.Error(), "too many connections") {
|
|
continue
|
|
}
|
|
return nil, lastErr
|
|
}
|
|
|
|
// Apply connection pool configuration
|
|
if err := applyConnectionPoolConfig(gormDb, cfg); err != nil {
|
|
lastErr = fmt.Errorf("failed to configure Postgres connection pool: %w", err)
|
|
continue
|
|
}
|
|
|
|
slog.Info("database connected successfully", "type", "Postgres", "attempt", attempt+1)
|
|
return gormDb, nil
|
|
|
|
case config.DatabaseSQLite:
|
|
gormDb, err = gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{
|
|
Logger: NewLogger(cfg.SlowQueryThreshold, cfg.Debug),
|
|
})
|
|
if err != nil {
|
|
lastErr = fmt.Errorf("failed to open SQLite database: %w", err)
|
|
continue
|
|
}
|
|
|
|
// SQLite doesn't benefit from connection pooling, set to 1
|
|
if err := applyConnectionPoolConfig(gormDb, config.DatabaseConfig{
|
|
MaxOpenConnections: 1,
|
|
MaxIdleConnections: 1,
|
|
ConnectionMaxLifetime: 0,
|
|
}); err != nil {
|
|
lastErr = fmt.Errorf("failed to configure SQLite connection pool: %w", err)
|
|
continue
|
|
}
|
|
|
|
slog.Info("database connected successfully", "type", "SQLite", "attempt", attempt+1)
|
|
return gormDb, nil
|
|
|
|
default:
|
|
return nil, fmt.Errorf("unknown database type: %s", cfg.Type)
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to connect to database after %d attempts: %w", maxRetries, lastErr)
|
|
}
|
|
|
|
// SqlRepo is a SQL database repository implementation.
|
|
// Currently, it supports MySQL, SQLite, Microsoft SQL and Postgresql database systems.
|
|
type SqlRepo struct {
|
|
db *gorm.DB
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewSqlRepository creates a new SqlRepo instance.
|
|
func NewSqlRepository(db *gorm.DB, cfg *config.Config) (*SqlRepo, error) {
|
|
repo := &SqlRepo{
|
|
db: db,
|
|
cfg: cfg,
|
|
}
|
|
|
|
if err := repo.preCheck(); err != nil {
|
|
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
|
}
|
|
|
|
if err := repo.migrate(); err != nil {
|
|
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
|
}
|
|
|
|
return repo, nil
|
|
}
|
|
|
|
func (r *SqlRepo) preCheck() error {
|
|
// WireGuard Portal v1 database migration table
|
|
type DatabaseMigrationInfo struct {
|
|
Version string `gorm:"primaryKey"`
|
|
Applied time.Time
|
|
}
|
|
|
|
// temporarily disable logger as the next request might fail (intentionally)
|
|
r.db.Logger.LogMode(logger.Silent)
|
|
defer func() { r.db.Logger.LogMode(logger.Info) }()
|
|
|
|
lastVersion := DatabaseMigrationInfo{}
|
|
err := r.db.Order("applied desc, version desc").FirstOrInit(&lastVersion).Error
|
|
if err != nil {
|
|
return nil // we probably don't have a V1 database =)
|
|
}
|
|
|
|
return fmt.Errorf("detected a WireGuard Portal V1 database (version: %s) - please migrate first",
|
|
lastVersion.Version)
|
|
}
|
|
|
|
func (r *SqlRepo) migrate() error {
|
|
slog.Debug("running migration: sys-stat", "result", r.db.AutoMigrate(&SysStat{}))
|
|
slog.Debug("running migration: user", "result", r.db.AutoMigrate(&domain.User{}))
|
|
slog.Debug("running migration: user webauthn credentials", "result",
|
|
r.db.AutoMigrate(&domain.UserWebauthnCredential{}))
|
|
slog.Debug("running migration: interface", "result", r.db.AutoMigrate(&domain.Interface{}))
|
|
slog.Debug("running migration: peer", "result", r.db.AutoMigrate(&domain.Peer{}))
|
|
slog.Debug("running migration: peer status", "result", r.db.AutoMigrate(&domain.PeerStatus{}))
|
|
slog.Debug("running migration: interface status", "result", r.db.AutoMigrate(&domain.InterfaceStatus{}))
|
|
slog.Debug("running migration: audit data", "result", r.db.AutoMigrate(&domain.AuditEntry{}))
|
|
slog.Debug("running migration: node sync lock", "result", r.db.AutoMigrate(&NodeSyncLock{}))
|
|
|
|
// Clean up deprecated columns from peer_statuses table (traffic accumulation refactor)
|
|
// Previously we had PreviousSessionBytesReceived/Transmitted columns, now we use a simpler approach
|
|
// These columns are safe to drop if they exist
|
|
if r.db.Migrator().HasColumn("peer_statuses", "previous_session_received") {
|
|
slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "previous_session_received")
|
|
r.db.Migrator().DropColumn("peer_statuses", "previous_session_received")
|
|
}
|
|
if r.db.Migrator().HasColumn("peer_statuses", "previous_session_transmitted") {
|
|
slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "previous_session_transmitted")
|
|
r.db.Migrator().DropColumn("peer_statuses", "previous_session_transmitted")
|
|
}
|
|
|
|
// Clean up accumulated traffic columns (no longer needed - use current session traffic only)
|
|
// Simplified to keep only current session bytes which are accurate from WireGuard
|
|
if r.db.Migrator().HasColumn("peer_statuses", "accumulated_received") {
|
|
slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "accumulated_received")
|
|
r.db.Migrator().DropColumn("peer_statuses", "accumulated_received")
|
|
}
|
|
if r.db.Migrator().HasColumn("peer_statuses", "accumulated_transmitted") {
|
|
slog.Debug("dropping deprecated column", "table", "peer_statuses", "column", "accumulated_transmitted")
|
|
r.db.Migrator().DropColumn("peer_statuses", "accumulated_transmitted")
|
|
}
|
|
|
|
existingSysStat := SysStat{}
|
|
r.db.Where("schema_version = ?", SchemaVersion).First(&existingSysStat)
|
|
if existingSysStat.SchemaVersion == 0 {
|
|
sysStat := SysStat{
|
|
MigratedAt: time.Now(),
|
|
SchemaVersion: SchemaVersion,
|
|
}
|
|
if err := r.db.Create(&sysStat).Error; err != nil {
|
|
return fmt.Errorf("failed to write sysstat entry for schema version %d: %w", SchemaVersion, err)
|
|
}
|
|
slog.Debug("sys-stat entry written", "schema_version", SchemaVersion)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// region interfaces
|
|
|
|
// GetInterface returns the interface with the given id.
|
|
// If no interface is found, an error domain.ErrNotFound is returned.
|
|
func (r *SqlRepo) GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) {
|
|
var in domain.Interface
|
|
|
|
err := r.db.WithContext(ctx).Preload("Addresses").First(&in, id).Error
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &in, nil
|
|
}
|
|
|
|
// GetInterfaceAndPeers returns the interface with the given id and all peers associated with it.
|
|
// If no interface is found, an error domain.ErrNotFound is returned.
|
|
func (r *SqlRepo) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (
|
|
*domain.Interface,
|
|
[]domain.Peer,
|
|
error,
|
|
) {
|
|
in, err := r.GetInterface(ctx, id)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to load interface: %w", err)
|
|
}
|
|
|
|
peers, err := r.GetInterfacePeers(ctx, id)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to load peers: %w", err)
|
|
}
|
|
|
|
return in, peers, nil
|
|
}
|
|
|
|
// GetPeersStats returns the stats for the given peer ids. The order of the returned stats is not guaranteed.
|
|
func (r *SqlRepo) GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) {
|
|
if len(ids) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
var stats []domain.PeerStatus
|
|
|
|
err := r.db.WithContext(ctx).Where("identifier IN ?", ids).Find(&stats).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return stats, nil
|
|
}
|
|
|
|
// GetAllInterfaces returns all interfaces.
|
|
func (r *SqlRepo) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) {
|
|
var interfaces []domain.Interface
|
|
|
|
err := r.db.WithContext(ctx).Preload("Addresses").Find(&interfaces).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return interfaces, nil
|
|
}
|
|
|
|
// GetInterfaceStats returns the stats for the given interface id.
|
|
// If no stats are found, an error domain.ErrNotFound is returned.
|
|
func (r *SqlRepo) GetInterfaceStats(ctx context.Context, id domain.InterfaceIdentifier) (
|
|
*domain.InterfaceStatus,
|
|
error,
|
|
) {
|
|
if id == "" {
|
|
return nil, nil
|
|
}
|
|
|
|
var stats []domain.InterfaceStatus
|
|
|
|
err := r.db.WithContext(ctx).Where("identifier = ?", id).Find(&stats).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(stats) == 0 {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
|
|
stat := stats[0]
|
|
|
|
return &stat, nil
|
|
}
|
|
|
|
// FindInterfaces returns all interfaces that match the given search string.
|
|
// The search string is matched against the interface identifier and display name.
|
|
func (r *SqlRepo) FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error) {
|
|
var users []domain.Interface
|
|
|
|
searchValue := "%" + strings.ToLower(search) + "%"
|
|
err := r.db.WithContext(ctx).
|
|
Where("identifier LIKE ?", searchValue).
|
|
Or("display_name LIKE ?", searchValue).
|
|
Preload("Addresses").
|
|
Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// SaveInterface updates the interface with the given id.
|
|
func (r *SqlRepo) SaveInterface(
|
|
ctx context.Context,
|
|
id domain.InterfaceIdentifier,
|
|
updateFunc func(in *domain.Interface) (*domain.Interface, error),
|
|
) error {
|
|
userInfo := domain.GetUserInfo(ctx)
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
in, err := r.getOrCreateInterface(userInfo, tx, id)
|
|
if err != nil {
|
|
return err // return any error will roll back
|
|
}
|
|
|
|
in, err = updateFunc(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = r.upsertInterface(userInfo, tx, in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// return nil will commit the whole transaction
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *SqlRepo) getOrCreateInterface(
|
|
ui *domain.ContextUserInfo,
|
|
tx *gorm.DB,
|
|
id domain.InterfaceIdentifier,
|
|
) (*domain.Interface, error) {
|
|
var in domain.Interface
|
|
|
|
// interfaceDefaults will be applied to newly created interface records
|
|
interfaceDefaults := domain.Interface{
|
|
BaseModel: domain.BaseModel{
|
|
CreatedBy: ui.UserId(),
|
|
UpdatedBy: ui.UserId(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
},
|
|
Identifier: id,
|
|
}
|
|
|
|
err := tx.Preload("Addresses").Attrs(interfaceDefaults).FirstOrCreate(&in, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &in, nil
|
|
}
|
|
|
|
func (r *SqlRepo) upsertInterface(ui *domain.ContextUserInfo, tx *gorm.DB, in *domain.Interface) error {
|
|
in.UpdatedBy = ui.UserId()
|
|
in.UpdatedAt = time.Now()
|
|
|
|
err := tx.Save(in).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Only update addresses if they were explicitly set (not nil)
|
|
// This prevents accidentally deleting all addresses when loading interface without preload
|
|
if in.Addresses != nil {
|
|
err = tx.Model(in).Association("Addresses").Replace(in.Addresses)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update interface addresses: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteInterface deletes the interface with the given id.
|
|
func (r *SqlRepo) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error {
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
err := tx.Where("interface_identifier = ?", id).Delete(&domain.Peer{}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Delete(&domain.InterfaceStatus{InterfaceId: id}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Select(clause.Associations).Delete(&domain.Interface{Identifier: id}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetInterfaceIps returns a map of interface identifiers to their respective IP addresses.
|
|
func (r *SqlRepo) GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) {
|
|
var ips []struct {
|
|
domain.Cidr
|
|
InterfaceId domain.InterfaceIdentifier `gorm:"column:interface_identifier"`
|
|
}
|
|
|
|
err := r.db.WithContext(ctx).
|
|
Table("interface_addresses").
|
|
Joins("LEFT JOIN cidrs ON interface_addresses.cidr_cidr = cidrs.cidr").
|
|
Scan(&ips).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make(map[domain.InterfaceIdentifier][]domain.Cidr)
|
|
for _, ip := range ips {
|
|
result[ip.InterfaceId] = append(result[ip.InterfaceId], ip.Cidr)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// endregion interfaces
|
|
|
|
// region peers
|
|
|
|
// GetPeer returns the peer with the given id.
|
|
// If no peer is found, an error domain.ErrNotFound is returned.
|
|
func (r *SqlRepo) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) {
|
|
var peer domain.Peer
|
|
|
|
err := r.db.WithContext(ctx).Preload("Addresses").First(&peer, id).Error
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &peer, nil
|
|
}
|
|
|
|
// GetInterfacePeers returns all peers associated with the given interface id.
|
|
func (r *SqlRepo) GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) {
|
|
var peers []domain.Peer
|
|
|
|
err := r.db.WithContext(ctx).Preload("Addresses").Where("interface_identifier = ?", id).Find(&peers).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// FindInterfacePeers returns all peers associated with the given interface id that match the given search string.
|
|
// The search string is matched against the peer identifier, display name and IP address.
|
|
func (r *SqlRepo) FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) (
|
|
[]domain.Peer,
|
|
error,
|
|
) {
|
|
var peers []domain.Peer
|
|
|
|
searchValue := "%" + strings.ToLower(search) + "%"
|
|
err := r.db.WithContext(ctx).Where("interface_identifier = ?", id).
|
|
Where("identifier LIKE ?", searchValue).
|
|
Or("display_name LIKE ?", searchValue).
|
|
Or("iface_address_str_v LIKE ?", searchValue).
|
|
Find(&peers).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// GetUserPeers returns all peers associated with the given user id.
|
|
func (r *SqlRepo) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
|
|
var peers []domain.Peer
|
|
|
|
err := r.db.WithContext(ctx).Preload("Addresses").Where("user_identifier = ?", id).Find(&peers).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// FindUserPeers returns all peers associated with the given user id that match the given search string.
|
|
// The search string is matched against the peer identifier, display name and IP address.
|
|
func (r *SqlRepo) FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error) {
|
|
var peers []domain.Peer
|
|
|
|
searchValue := "%" + strings.ToLower(search) + "%"
|
|
err := r.db.WithContext(ctx).Where("user_identifier = ?", id).
|
|
Where("identifier LIKE ?", searchValue).
|
|
Or("display_name LIKE ?", searchValue).
|
|
Or("iface_address_str_v LIKE ?", searchValue).
|
|
Find(&peers).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// SavePeer updates the peer with the given id.
|
|
// If no existing peer is found, a new peer is created.
|
|
// IMPORTANT: Also creates peer_status record to avoid deadlock during concurrent updates
|
|
func (r *SqlRepo) SavePeer(
|
|
ctx context.Context,
|
|
id domain.PeerIdentifier,
|
|
updateFunc func(in *domain.Peer) (*domain.Peer, error),
|
|
) error {
|
|
userInfo := domain.GetUserInfo(ctx)
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
peer, err := r.getOrCreatePeer(userInfo, tx, id)
|
|
if err != nil {
|
|
return err // return any error will roll back
|
|
}
|
|
|
|
peer, err = updateFunc(peer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = r.upsertPeer(userInfo, tx, peer)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// DEADLOCK FIX: Ensure peer_status record exists WITHOUT FOR UPDATE lock
|
|
// This prevents lock contention when stats collection immediately tries to update it
|
|
// Use explicit WHERE clause with correct column name (identifier, not peer_id)
|
|
peerStatus := domain.PeerStatus{
|
|
PeerId: id,
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
// Explicit WHERE using correct column name for FirstOrCreate
|
|
err = tx.Where("identifier = ?", id).FirstOrCreate(&peerStatus).Error
|
|
if err != nil {
|
|
// Log but don't fail - peer_status will be created on first stats update anyway
|
|
slog.Debug("peer status record already exists or creation deferred", "peer", id)
|
|
}
|
|
|
|
// return nil will commit the whole transaction
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *SqlRepo) getOrCreatePeer(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.PeerIdentifier) (
|
|
*domain.Peer,
|
|
error,
|
|
) {
|
|
var peer domain.Peer
|
|
|
|
// interfaceDefaults will be applied to newly created interface records
|
|
interfaceDefaults := domain.Peer{
|
|
BaseModel: domain.BaseModel{
|
|
CreatedBy: ui.UserId(),
|
|
UpdatedBy: ui.UserId(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
},
|
|
Identifier: id,
|
|
}
|
|
|
|
err := tx.Attrs(interfaceDefaults).FirstOrCreate(&peer, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &peer, nil
|
|
}
|
|
|
|
func (r *SqlRepo) upsertPeer(ui *domain.ContextUserInfo, tx *gorm.DB, peer *domain.Peer) error {
|
|
peer.UpdatedBy = ui.UserId()
|
|
peer.UpdatedAt = time.Now()
|
|
|
|
err := tx.Save(peer).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Model(peer).Association("Addresses").Replace(peer.Interface.Addresses)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update peer addresses: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeletePeer deletes the peer with the given id.
|
|
// This also deletes the peer_addresses associations (many-to-many relationships with Cidr)
|
|
// NOTE: We do NOT delete peer_status here.
|
|
// The peer_status will be cleaned up by CleanOrphanedStatuses on all cluster nodes.
|
|
// This ensures that other nodes can detect orphaned statuses and clean up their metrics.
|
|
func (r *SqlRepo) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error {
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// First: delete peer_addresses many2many associations to avoid foreign key issues
|
|
// This is critical because we use raw SQL and explicit deletion like in DeletePeersByIDs
|
|
if err := tx.Table("peer_addresses").Where("peer_identifier = ?", string(id)).Delete(nil).Error; err != nil {
|
|
return err
|
|
}
|
|
|
|
// Second: delete the peer itself
|
|
err := tx.Where("identifier = ?", string(id)).Delete(&domain.Peer{}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetPeerIps returns a map of peer identifiers to their respective IP addresses.
|
|
func (r *SqlRepo) GetPeerIps(ctx context.Context) (map[domain.PeerIdentifier][]domain.Cidr, error) {
|
|
var ips []struct {
|
|
domain.Cidr
|
|
PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
|
|
}
|
|
|
|
err := r.db.WithContext(ctx).
|
|
Table("peer_addresses").
|
|
Joins("LEFT JOIN cidrs ON peer_addresses.cidr_cidr = cidrs.cidr").
|
|
Scan(&ips).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
result := make(map[domain.PeerIdentifier][]domain.Cidr)
|
|
for _, ip := range ips {
|
|
result[ip.PeerId] = append(result[ip.PeerId], ip.Cidr)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// GetUsedIpsPerSubnet returns a map of subnets to their respective used IP addresses.
|
|
func (r *SqlRepo) GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (
|
|
map[domain.Cidr][]domain.Cidr,
|
|
error,
|
|
) {
|
|
var peerIps []struct {
|
|
domain.Cidr
|
|
PeerId domain.PeerIdentifier `gorm:"column:peer_identifier"`
|
|
}
|
|
|
|
err := r.db.WithContext(ctx).
|
|
Table("peer_addresses").
|
|
Joins("LEFT JOIN cidrs ON peer_addresses.cidr_cidr = cidrs.cidr").
|
|
Scan(&peerIps).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch peer IP's: %w", err)
|
|
}
|
|
|
|
var interfaceIps []struct {
|
|
domain.Cidr
|
|
InterfaceId domain.InterfaceIdentifier `gorm:"column:interface_identifier"`
|
|
}
|
|
|
|
err = r.db.WithContext(ctx).
|
|
Table("interface_addresses").
|
|
Joins("LEFT JOIN cidrs ON interface_addresses.cidr_cidr = cidrs.cidr").
|
|
Scan(&interfaceIps).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch interface IP's: %w", err)
|
|
}
|
|
|
|
result := make(map[domain.Cidr][]domain.Cidr, len(subnets))
|
|
for _, ip := range interfaceIps {
|
|
var subnet domain.Cidr // default empty subnet (if no subnet matches, we will add the IP to the empty subnet group)
|
|
for _, s := range subnets {
|
|
if s.Contains(ip.Cidr) {
|
|
subnet = s
|
|
break
|
|
}
|
|
}
|
|
result[subnet] = append(result[subnet], ip.Cidr)
|
|
}
|
|
for _, ip := range peerIps {
|
|
var subnet domain.Cidr // default empty subnet (if no subnet matches, we will add the IP to the empty subnet group)
|
|
for _, s := range subnets {
|
|
if s.Contains(ip.Cidr) {
|
|
subnet = s
|
|
break
|
|
}
|
|
}
|
|
result[subnet] = append(result[subnet], ip.Cidr)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// endregion peers
|
|
|
|
// region users
|
|
|
|
// GetUser returns the user with the given id.
|
|
// If no user is found, an error domain.ErrNotFound is returned.
|
|
func (r *SqlRepo) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) {
|
|
var user domain.User
|
|
|
|
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").First(&user, id).Error
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// GetUserByEmail returns the user with the given email.
|
|
// If no user is found, an error domain.ErrNotFound is returned.
|
|
// If multiple users are found, an error domain.ErrNotUnique is returned.
|
|
func (r *SqlRepo) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
|
var users []domain.User
|
|
|
|
err := r.db.WithContext(ctx).Where("email = ?", email).Preload("WebAuthnCredentialList").Find(&users).Error
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(users) == 0 {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
|
|
if len(users) > 1 {
|
|
return nil, fmt.Errorf("found multiple users with email %s: %w", email, domain.ErrNotUnique)
|
|
}
|
|
|
|
user := users[0]
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
// GetUserByWebAuthnCredential returns the user with the given webauthn credential id.
|
|
func (r *SqlRepo) GetUserByWebAuthnCredential(ctx context.Context, credentialIdBase64 string) (*domain.User, error) {
|
|
var credential domain.UserWebauthnCredential
|
|
|
|
err := r.db.WithContext(ctx).Where("credential_identifier = ?", credentialIdBase64).First(&credential).Error
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, domain.ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return r.GetUser(ctx, domain.UserIdentifier(credential.UserIdentifier))
|
|
}
|
|
|
|
// GetAllUsers returns all users.
|
|
func (r *SqlRepo) GetAllUsers(ctx context.Context) ([]domain.User, error) {
|
|
var users []domain.User
|
|
|
|
err := r.db.WithContext(ctx).Preload("WebAuthnCredentialList").Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// FindUsers returns all users that match the given search string.
|
|
// The search string is matched against the user identifier, firstname, lastname and email.
|
|
func (r *SqlRepo) FindUsers(ctx context.Context, search string) ([]domain.User, error) {
|
|
var users []domain.User
|
|
|
|
searchValue := "%" + strings.ToLower(search) + "%"
|
|
err := r.db.WithContext(ctx).
|
|
Where("identifier LIKE ?", searchValue).
|
|
Or("firstname LIKE ?", searchValue).
|
|
Or("lastname LIKE ?", searchValue).
|
|
Or("email LIKE ?", searchValue).
|
|
Preload("WebAuthnCredentialList").
|
|
Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return users, nil
|
|
}
|
|
|
|
// SaveUser updates the user with the given id.
|
|
// If no user is found, a new user is created.
|
|
func (r *SqlRepo) SaveUser(
|
|
ctx context.Context,
|
|
id domain.UserIdentifier,
|
|
updateFunc func(u *domain.User) (*domain.User, error),
|
|
) error {
|
|
userInfo := domain.GetUserInfo(ctx)
|
|
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
user, err := r.getOrCreateUser(userInfo, tx, id)
|
|
if err != nil {
|
|
return err // return any error will roll back
|
|
}
|
|
|
|
user, err = updateFunc(user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = r.upsertUser(userInfo, tx, user)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// return nil will commit the whole transaction
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeleteUser deletes the user with the given id.
|
|
func (r *SqlRepo) DeleteUser(ctx context.Context, id domain.UserIdentifier) error {
|
|
err := r.db.WithContext(ctx).Unscoped().Select(clause.Associations).Delete(&domain.User{Identifier: id}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *SqlRepo) getOrCreateUser(ui *domain.ContextUserInfo, tx *gorm.DB, id domain.UserIdentifier) (
|
|
*domain.User,
|
|
error,
|
|
) {
|
|
var user domain.User
|
|
|
|
// userDefaults will be applied to newly created user records
|
|
userDefaults := domain.User{
|
|
BaseModel: domain.BaseModel{
|
|
CreatedBy: ui.UserId(),
|
|
UpdatedBy: ui.UserId(),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
},
|
|
Identifier: id,
|
|
Source: domain.UserSourceDatabase,
|
|
IsAdmin: false,
|
|
}
|
|
|
|
err := tx.Attrs(userDefaults).FirstOrCreate(&user, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &user, nil
|
|
}
|
|
|
|
func (r *SqlRepo) upsertUser(ui *domain.ContextUserInfo, tx *gorm.DB, user *domain.User) error {
|
|
user.UpdatedBy = ui.UserId()
|
|
user.UpdatedAt = time.Now()
|
|
|
|
err := tx.Save(user).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = tx.Session(&gorm.Session{FullSaveAssociations: true}).Unscoped().Model(user).Association("WebAuthnCredentialList").Unscoped().Replace(user.WebAuthnCredentialList)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update users webauthn credentials: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion users
|
|
|
|
// region statistics
|
|
|
|
// UpdateInterfaceStatus updates the interface status with the given id.
|
|
// If no interface status is found, a new one is created.
|
|
func (r *SqlRepo) UpdateInterfaceStatus(
|
|
ctx context.Context,
|
|
id domain.InterfaceIdentifier,
|
|
updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error),
|
|
) error {
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
in, err := r.getOrCreateInterfaceStatus(tx, id)
|
|
if err != nil {
|
|
return err // return any error will roll back
|
|
}
|
|
|
|
in, err = updateFunc(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = r.upsertInterfaceStatus(tx, in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// return nil will commit the whole transaction
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *SqlRepo) getOrCreateInterfaceStatus(tx *gorm.DB, id domain.InterfaceIdentifier) (
|
|
*domain.InterfaceStatus,
|
|
error,
|
|
) {
|
|
var in domain.InterfaceStatus
|
|
|
|
// defaults will be applied to newly created record
|
|
defaults := domain.InterfaceStatus{
|
|
InterfaceId: id,
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
err := tx.Attrs(defaults).FirstOrCreate(&in, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &in, nil
|
|
}
|
|
|
|
func (r *SqlRepo) upsertInterfaceStatus(tx *gorm.DB, in *domain.InterfaceStatus) error {
|
|
err := tx.Save(in).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdatePeerStatus updates the peer status with the given id.
|
|
// If no peer status is found, a new one is created.
|
|
// OWNERSHIP CHECK: If peer has an OwnerNodeId set, skip update to prevent conflicts with owner node.
|
|
// Includes automatic retry logic with exponential backoff for deadlock/conflict recovery.
|
|
// Uses row-level FOR UPDATE locking to prevent concurrent modification conflicts.
|
|
func (r *SqlRepo) UpdatePeerStatus(
|
|
ctx context.Context,
|
|
id domain.PeerIdentifier,
|
|
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
|
) error {
|
|
// Retry logic with exponential backoff for deadlocks and conflicts
|
|
// Increased to 10 retries (50ms, 100ms, 200ms, 400ms, 800ms, 1.6s, 3.2s, 6.4s, 12.8s, 25.6s)
|
|
// This handles high-concurrency scenarios with 24 cluster nodes and 100+ peers
|
|
// Total max wait time: ~50 seconds for all retries
|
|
const maxRetries = 10
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
if attempt > 0 {
|
|
// Exponential backoff with jitter to reduce synchronized retries
|
|
// Base: 50ms * 2^(attempt-1) + random 0-50ms jitter
|
|
baseWait := time.Duration(50*(1<<uint(attempt-1))) * time.Millisecond
|
|
jitter := time.Duration(rand.Intn(50)) * time.Millisecond
|
|
waitTime := baseWait + jitter
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(waitTime):
|
|
}
|
|
}
|
|
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
in, err := r.getOrCreatePeerStatus(tx, id)
|
|
if err != nil {
|
|
return err // return any error will roll back
|
|
}
|
|
|
|
// Save original state to restore conflicting fields if peer is owned by another node
|
|
oldIsConnected := in.IsConnected
|
|
oldIsPingable := in.IsPingable
|
|
oldOwnerNodeId := in.OwnerNodeId
|
|
|
|
// IMPORTANT: Always allow updateFunc to proceed!
|
|
// This ensures non-conflicting data (BytesReceived, BytesTransmitted, LastHandshake, Endpoint)
|
|
// is captured for ALL nodes, not just the owner.
|
|
// We'll selectively restore conflicting fields below if owned by another node.
|
|
in, err = updateFunc(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// OWNERSHIP CHECK: If peer is CONNECTED and owned by another node,
|
|
// restore the conflicting state fields to prevent this node from overwriting owner's state
|
|
// BUT keep non-conflicting data like bytes and endpoint
|
|
if oldIsConnected && oldOwnerNodeId != "" && oldOwnerNodeId != r.cfg.Core.ClusterNodeId {
|
|
slog.Info("peer status owned by other node, restoring conflicting fields but KEEPING bytes",
|
|
"peer", id,
|
|
"owner", oldOwnerNodeId,
|
|
"our_node", r.cfg.Core.ClusterNodeId,
|
|
"restored_is_connected", oldIsConnected,
|
|
"bytes_received", in.BytesReceived,
|
|
"bytes_transmitted", in.BytesTransmitted,
|
|
"last_handshake", in.LastHandshake,
|
|
"endpoint", in.Endpoint)
|
|
// Restore conflicting fields to prevent state conflicts
|
|
in.IsConnected = oldIsConnected
|
|
in.IsPingable = oldIsPingable
|
|
in.OwnerNodeId = oldOwnerNodeId
|
|
// Keep the non-conflicting data:
|
|
// - BytesReceived, BytesTransmitted (traffic data)
|
|
// - LastHandshake, Endpoint (connection details)
|
|
// - LastSessionStart, LastPing (timestamps)
|
|
} else {
|
|
// Peer is either offline or not owned by another node - safe to save all data
|
|
slog.Debug("peer status update - saving all data",
|
|
"peer", id,
|
|
"is_connected", in.IsConnected,
|
|
"owned_by", in.OwnerNodeId,
|
|
"bytes_received", in.BytesReceived,
|
|
"bytes_transmitted", in.BytesTransmitted)
|
|
}
|
|
|
|
// CRITICAL: When peer transitions to OFFLINE, clear its ownership
|
|
// This allows any node to update it in future cycles
|
|
if !in.IsConnected && in.OwnerNodeId != "" {
|
|
slog.Debug("clearing peer ownership on offline transition",
|
|
"peer", id, "old_owner", in.OwnerNodeId)
|
|
in.OwnerNodeId = ""
|
|
}
|
|
|
|
err = r.upsertPeerStatus(tx, in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// return nil will commit the whole transaction
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
return nil // success
|
|
}
|
|
|
|
lastErr = err
|
|
// Check if this is a retryable error (deadlock or record changed)
|
|
errMsg := err.Error()
|
|
if !strings.Contains(errMsg, "Deadlock") && !strings.Contains(errMsg, "Record has changed") {
|
|
return err // not a retryable error
|
|
}
|
|
// Continue to retry
|
|
}
|
|
|
|
return fmt.Errorf("UpdatePeerStatus failed after %d retries: %w", maxRetries, lastErr)
|
|
}
|
|
|
|
// ClaimPeerStatus claims ownership of a peer status for this node.
|
|
// Only the owner node should update peer status to avoid conflicts.
|
|
// Once claimed, only this node (and nodes with same owner_node_id) can update it.
|
|
// Uses row-level FOR UPDATE locking to serialize concurrent claims.
|
|
func (r *SqlRepo) ClaimPeerStatus(
|
|
ctx context.Context,
|
|
id domain.PeerIdentifier,
|
|
ownerNodeId string,
|
|
updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
|
) error {
|
|
if ownerNodeId == "" {
|
|
return fmt.Errorf("ownerNodeId cannot be empty")
|
|
}
|
|
|
|
// Retry logic with exponential backoff for deadlocks and conflicts
|
|
// Increased to 10 retries for high-concurrency multi-node ownership claiming
|
|
// Total max wait time: ~50 seconds for all retries
|
|
const maxRetries = 10
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
if attempt > 0 {
|
|
// Exponential backoff with jitter to reduce synchronized retries
|
|
// Base: 50ms * 2^(attempt-1) + random 0-50ms jitter
|
|
baseWait := time.Duration(50*(1<<uint(attempt-1))) * time.Millisecond
|
|
jitter := time.Duration(rand.Intn(50)) * time.Millisecond
|
|
waitTime := baseWait + jitter
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(waitTime):
|
|
}
|
|
}
|
|
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
in, err := r.getOrCreatePeerStatus(tx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Claim ownership: set before calling updateFunc
|
|
in.OwnerNodeId = ownerNodeId
|
|
in.UpdatedAt = time.Now()
|
|
|
|
// Apply the custom update function
|
|
in, err = updateFunc(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Ensure ownership is maintained
|
|
in.OwnerNodeId = ownerNodeId
|
|
|
|
err = r.upsertPeerStatus(tx, in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
return nil // success
|
|
}
|
|
|
|
lastErr = err
|
|
// Check if this is a retryable error (deadlock or record changed)
|
|
errMsg := err.Error()
|
|
if !strings.Contains(errMsg, "Deadlock") && !strings.Contains(errMsg, "Record has changed") {
|
|
return err // not a retryable error
|
|
}
|
|
// Continue to retry
|
|
}
|
|
|
|
return fmt.Errorf("ClaimPeerStatus failed after %d retries: %w", maxRetries, lastErr)
|
|
}
|
|
|
|
// BatchUpdatePeerStatuses updates multiple peer statuses in a single optimized transaction.
|
|
// This is more efficient than calling UpdatePeerStatus individually and reduces deadlock risk.
|
|
// OWNERSHIP CHECK: Skips peers owned by other nodes to prevent conflicts.
|
|
// Uses ON CONFLICT DO UPDATE for bulk upsert with automatic conflict resolution.
|
|
func (r *SqlRepo) BatchUpdatePeerStatuses(
|
|
ctx context.Context,
|
|
updates map[domain.PeerIdentifier]func(in *domain.PeerStatus) (*domain.PeerStatus, error),
|
|
) error {
|
|
if len(updates) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Retry logic for batch operations - more retries for bulk operations
|
|
const maxRetries = 5
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
|
if attempt > 0 {
|
|
// Exponential backoff: 50ms, 100ms, 200ms, 400ms, 800ms
|
|
waitTime := time.Duration(50*(1<<uint(attempt-1))) * time.Millisecond
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-time.After(waitTime):
|
|
}
|
|
}
|
|
|
|
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
|
var allStatuses []*domain.PeerStatus
|
|
|
|
// Load all peer statuses we need to update
|
|
for id := range updates {
|
|
in, err := r.getOrCreatePeerStatus(tx, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// OWNERSHIP CHECK: Skip peers owned by other nodes
|
|
if in.OwnerNodeId != "" {
|
|
// Peer is owned by another node - skip update
|
|
slog.Debug("peer status owned by other node, skipping batch update",
|
|
"peer", id, "owner", in.OwnerNodeId)
|
|
continue
|
|
}
|
|
|
|
// Apply the update function
|
|
updateFunc := updates[id]
|
|
in, err = updateFunc(in)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
allStatuses = append(allStatuses, in)
|
|
}
|
|
|
|
// Batch insert/update all statuses in one operation
|
|
if len(allStatuses) > 0 {
|
|
err := tx.Clauses(clause.OnConflict{
|
|
UpdateAll: true,
|
|
}).CreateInBatches(allStatuses, 50).Error // batch insert in groups of 50
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err == nil {
|
|
return nil // success
|
|
}
|
|
|
|
lastErr = err
|
|
// Check if this is a retryable error
|
|
errMsg := err.Error()
|
|
if !strings.Contains(errMsg, "Deadlock") && !strings.Contains(errMsg, "Record has changed") {
|
|
return err // not a retryable error
|
|
}
|
|
// Continue to retry
|
|
}
|
|
|
|
return fmt.Errorf("BatchUpdatePeerStatuses failed after %d retries: %w", maxRetries, lastErr)
|
|
}
|
|
|
|
func (r *SqlRepo) getOrCreatePeerStatus(tx *gorm.DB, id domain.PeerIdentifier) (*domain.PeerStatus, error) {
|
|
var in domain.PeerStatus
|
|
|
|
// defaults will be applied to newly created record
|
|
defaults := domain.PeerStatus{
|
|
PeerId: id,
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
// DEADLOCK FIX: Two-phase approach to reduce lock contention
|
|
// 1. First try to get existing record with FOR UPDATE lock
|
|
// (This avoids gap locks on non-existent rows)
|
|
err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("identifier = ?", id).First(&in).Error
|
|
if err == nil {
|
|
// Record found and locked - proceed with update
|
|
return &in, nil
|
|
}
|
|
|
|
if err != gorm.ErrRecordNotFound {
|
|
// Unexpected error
|
|
return nil, err
|
|
}
|
|
|
|
// Record doesn't exist - create it WITHOUT FOR UPDATE lock
|
|
// This is safe because SavePeer pre-creates peer_status, so this is rare
|
|
// Explicit WHERE clause using correct column name
|
|
err = tx.Where("identifier = ?", id).Attrs(defaults).FirstOrCreate(&in).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Now that record exists, acquire FOR UPDATE lock for modifications
|
|
err = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("identifier = ?", id).First(&in).Error
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lock peer status after creation: %w", err)
|
|
}
|
|
|
|
return &in, nil
|
|
}
|
|
|
|
func (r *SqlRepo) getOrCreatePeerStatusForRead(tx *gorm.DB, id domain.PeerIdentifier) (*domain.PeerStatus, error) {
|
|
var in domain.PeerStatus
|
|
|
|
// defaults will be applied to newly created record
|
|
defaults := domain.PeerStatus{
|
|
PeerId: id,
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
// For read-only access without modifications (used in conditionals)
|
|
// No FOR UPDATE lock - allows concurrent reads
|
|
err := tx.Attrs(defaults).FirstOrCreate(&in, id).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &in, nil
|
|
}
|
|
|
|
func (r *SqlRepo) upsertPeerStatus(tx *gorm.DB, in *domain.PeerStatus) error {
|
|
err := tx.Save(in).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// DeletePeerStatus deletes the peer status with the given id.
|
|
func (r *SqlRepo) DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error {
|
|
err := r.db.WithContext(ctx).Delete(&domain.PeerStatus{}, id).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAllPeerStatuses returns all peer statuses from the database.
|
|
func (r *SqlRepo) GetAllPeerStatuses(ctx context.Context) ([]domain.PeerStatus, error) {
|
|
var statuses []domain.PeerStatus
|
|
|
|
err := r.db.WithContext(ctx).Find(&statuses).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return statuses, nil
|
|
}
|
|
|
|
// endregion statistics
|
|
|
|
// region audit
|
|
|
|
// SaveAuditEntry saves the given audit entry.
|
|
func (r *SqlRepo) SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error {
|
|
err := r.db.WithContext(ctx).Save(entry).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAllAuditEntries retrieves all audit entries from the database.
|
|
// The entries are ordered by timestamp, with the newest entries first.
|
|
func (r *SqlRepo) GetAllAuditEntries(ctx context.Context) ([]domain.AuditEntry, error) {
|
|
var entries []domain.AuditEntry
|
|
err := r.db.WithContext(ctx).Order("created_at desc").Find(&entries).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return entries, nil
|
|
}
|
|
|
|
// endregion audit
|
|
|
|
func (r *SqlRepo) GetAllPeers(ctx context.Context) ([]domain.Peer, error) {
|
|
var peers []domain.Peer
|
|
// OPTIMIZED: Select only peer identifiers and basic fields to avoid N+1 queries
|
|
// from Preload("Addresses") and Preload("Interface") which would create 600+ queries
|
|
// This method is used mainly for cleanup validation, not for full peer data display
|
|
err := r.db.WithContext(ctx).
|
|
Select("id, identifier, display_name, interface_identifier").
|
|
Find(&peers).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// region node sync lock
|
|
|
|
// Settings of the global peer sync lock.
const (
	// syncLockKey is the lock_key value identifying the peer sync lock.
	syncLockKey = "peer:sync"
	// syncLockDuration is the lifetime of an acquired lock; a lock that is not
	// released expires after this time.
	syncLockDuration = 5 * time.Minute
	// syncLockTimeout bounds the context used while acquiring the lock.
	syncLockTimeout = 2 * time.Minute
)
|
|
|
|
// AcquireSyncLock attempts to acquire the global peer sync lock.
|
|
// Returns nodeID that holds the lock, or empty string if we acquired it.
|
|
// Uses exponential backoff to retry up to syncLockTimeout.
|
|
func (r *SqlRepo) AcquireSyncLock(ctx context.Context, nodeID string) (acquiredBy string, err error) {
|
|
// Use generic lock mechanism
|
|
return r.AcquireLock(ctx, syncLockKey, nodeID, syncLockDuration)
|
|
}
|
|
|
|
// ReleaseSyncLock releases the global peer sync lock.
|
|
func (r *SqlRepo) ReleaseSyncLock(ctx context.Context, nodeID string) error {
|
|
return r.ReleaseLock(ctx, syncLockKey, nodeID)
|
|
}
|
|
|
|
// SyncAllPeersFromDBWithLock wraps sync to ensure only one node syncs at a time.
|
|
// This prevents database contention and 504 Gateway Timeout cascades.
|
|
func (r *SqlRepo) SyncAllPeersFromDBWithLock(ctx context.Context, nodeID string) (int, error) {
|
|
// Try to acquire lock with timeout
|
|
lockCtx, cancel := context.WithTimeout(ctx, syncLockTimeout)
|
|
defer cancel()
|
|
|
|
heldBy, err := r.AcquireSyncLock(lockCtx, nodeID)
|
|
if err != nil {
|
|
slog.Warn("[SYNC_LOCK] cannot acquire sync lock, another node is syncing",
|
|
"held_by", heldBy, "self", nodeID, "error", err)
|
|
// Return 0,0 to indicate no sync performed (not an error condition)
|
|
return 0, nil
|
|
}
|
|
|
|
// Ensure lock is released after we're done (even if sync fails)
|
|
defer func() {
|
|
if relErr := r.ReleaseSyncLock(context.Background(), nodeID); relErr != nil {
|
|
slog.Error("failed to release sync lock", "error", relErr)
|
|
}
|
|
// Clean up expired locks in background to prevent table bloat
|
|
go func() {
|
|
goCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
if cleanErr := r.CleanupExpiredLocks(goCtx); cleanErr != nil {
|
|
slog.Debug("cleanup expired locks failed", "error", cleanErr)
|
|
}
|
|
}()
|
|
}()
|
|
|
|
// Now perform the actual sync with the lock held
|
|
count, err := r.SyncAllPeersFromDB(ctx)
|
|
if err != nil {
|
|
slog.Error("[SYNC_LOCK] sync failed while holding lock", "node_id", nodeID, "error", err)
|
|
return count, err
|
|
}
|
|
|
|
slog.Info("[SYNC_LOCK] sync completed", "node_id", nodeID, "synced_count", count)
|
|
return count, nil
|
|
}
|
|
|
|
// endregion node sync lock
|
|
|
|
// SyncAllPeersFromDB synchronizes all peers from the database.
|
|
func (r *SqlRepo) SyncAllPeersFromDB(ctx context.Context) (int, error) {
|
|
slog.Debug("SyncAllPeersFromDB called, but no operation performed")
|
|
return 0, nil
|
|
}
|
|
|
|
// region expired peers cleanup
|
|
|
|
// Settings for the cleanup of expired peers.
const (
	// expireCleanupLockKey is the lock key reserved for the expired-peer cleanup.
	expireCleanupLockKey = "expire:cleanup"
	// expireCleanupLockDuration is the lifetime intended for the cleanup lock,
	// so that only one node performs the cleanup.
	expireCleanupLockDuration = 10 * time.Minute
	// expireCleanupTimeout is the time budget for acquiring the cleanup lock.
	expireCleanupTimeout = 30 * time.Second
)
|
|
|
|
// GetExpiredPeers finds all peers with expiredAt in the past
|
|
func (r *SqlRepo) GetExpiredPeers(ctx context.Context) ([]domain.Peer, error) {
|
|
var peers []domain.Peer
|
|
now := time.Now()
|
|
|
|
err := r.db.WithContext(ctx).
|
|
Where("expires_at IS NOT NULL AND expires_at < ?", now).
|
|
Select("identifier, interface_identifier").
|
|
Find(&peers).Error
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return peers, nil
|
|
}
|
|
|
|
// DeletePeersByIDs deletes peers by their Identifier (public key) and their associated addresses
|
|
// Used when cleaning up expired peers
|
|
// IMPORTANT: Deletes peer_addresses associations first to avoid foreign key constraint violations
|
|
// Also deletes peer statuses to avoid orphaned status records
|
|
func (r *SqlRepo) DeletePeersByIDs(ctx context.Context, peerIDs []string) (int64, error) {
|
|
if len(peerIDs) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
// Start transaction to ensure atomic deletion (associations first, then peers, then statuses)
|
|
tx := r.db.WithContext(ctx).Begin()
|
|
if tx.Error != nil {
|
|
return 0, tx.Error
|
|
}
|
|
|
|
// First: delete peer_addresses many2many associations via raw SQL
|
|
if err := tx.Table("peer_addresses").Where("peer_identifier IN ?", peerIDs).Delete(nil).Error; err != nil {
|
|
tx.Rollback()
|
|
return 0, err
|
|
}
|
|
|
|
// Second: delete peer statuses to avoid orphaned status records
|
|
if err := tx.Where("identifier IN ?", peerIDs).Delete(&domain.PeerStatus{}).Error; err != nil {
|
|
tx.Rollback()
|
|
return 0, err
|
|
}
|
|
|
|
// Third: delete peers
|
|
result := tx.Where("identifier IN ?", peerIDs).Delete(&domain.Peer{})
|
|
if result.Error != nil {
|
|
tx.Rollback()
|
|
return 0, result.Error
|
|
}
|
|
|
|
// Commit transaction
|
|
if err := tx.Commit().Error; err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return result.RowsAffected, nil
|
|
}
|
|
|
|
// FindAndDeleteExpiredPeersWithLock ensures only MASTER node deletes expired peers
|
|
// Other nodes skip cleanup to avoid conflicts across multiple nodes
|
|
// If this node is not configured as master, cleanup is skipped
|
|
func (r *SqlRepo) FindAndDeleteExpiredPeersWithLock(ctx context.Context, nodeID string) (expiredPeerIDs []string, err error) {
|
|
// Only MASTER node can delete expired peers
|
|
if !r.cfg.Core.Master {
|
|
slog.Debug("[EXPIRE_CLEANUP] this node is not master, skipping cleanup", "node_id", nodeID)
|
|
return nil, nil
|
|
}
|
|
|
|
// Find expired peers
|
|
expiredPeers, err := r.GetExpiredPeers(ctx)
|
|
if err != nil {
|
|
slog.Error("[EXPIRE_CLEANUP] failed to get expired peers", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
if len(expiredPeers) == 0 {
|
|
slog.Debug("[EXPIRE_CLEANUP] no expired peers found")
|
|
return nil, nil
|
|
}
|
|
|
|
// Delete them from DB
|
|
peerIDs := make([]string, len(expiredPeers))
|
|
for i, p := range expiredPeers {
|
|
peerIDs[i] = string(p.Identifier)
|
|
}
|
|
|
|
deletedCount, err := r.DeletePeersByIDs(ctx, peerIDs)
|
|
if err != nil {
|
|
slog.Error("[EXPIRE_CLEANUP] failed to delete expired peers", "error", err, "count", len(peerIDs))
|
|
return nil, err
|
|
}
|
|
|
|
slog.Info("[EXPIRE_CLEANUP] deleted expired peers",
|
|
"node_id", nodeID, "count", deletedCount, "peer_ids_count", len(peerIDs))
|
|
|
|
return peerIDs, nil
|
|
}
|
|
|
|
// endregion expired peers cleanup
|
|
|
|
// region generic distributed locking (for reuse)
|
|
|
|
// AcquireLock is a generic lock mechanism for various operations.
|
|
// IMPORTANT: Uses simple INSERT without ON DUPLICATE KEY UPDATE to avoid deadlocks.
|
|
// If lock already exists and is not expired, returns who holds it (no retry loop).
|
|
func (r *SqlRepo) AcquireLock(ctx context.Context, lockKey string, nodeID string, duration time.Duration) (acquiredBy string, err error) {
|
|
now := time.Now()
|
|
|
|
// Clean up expired locks (but don't retry on failure)
|
|
if delErr := r.db.WithContext(ctx).
|
|
Where("lock_key = ? AND expires_at < ?", lockKey, now).
|
|
Delete(&NodeSyncLock{}).Error; delErr != nil {
|
|
slog.Debug("failed to clean expired locks", "lock_key", lockKey, "error", delErr)
|
|
}
|
|
|
|
// Try simple INSERT (no ON DUPLICATE KEY UPDATE to avoid deadlock)
|
|
result := r.db.WithContext(ctx).
|
|
Create(&NodeSyncLock{
|
|
LockKey: lockKey,
|
|
NodeID: nodeID,
|
|
LockedAt: now,
|
|
ExpiresAt: now.Add(duration),
|
|
})
|
|
|
|
if result.Error == nil {
|
|
slog.Debug("[LOCK] acquired", "lock_key", lockKey, "node_id", nodeID)
|
|
return "", nil
|
|
}
|
|
|
|
// INSERT failed - check who holds the lock (no retry, just check)
|
|
var lock NodeSyncLock
|
|
if err := r.db.WithContext(ctx).
|
|
Where("lock_key = ?", lockKey).
|
|
First(&lock).Error; err == nil {
|
|
if lock.ExpiresAt.Before(now) {
|
|
// Lock expired, clean it up (best effort)
|
|
r.db.WithContext(ctx).Delete(&lock)
|
|
return "", fmt.Errorf("lock expired, please retry")
|
|
}
|
|
// Lock is held by another node
|
|
acquiredBy = lock.NodeID
|
|
return acquiredBy, fmt.Errorf("lock held by %s", acquiredBy)
|
|
}
|
|
|
|
// Couldn't determine lock holder
|
|
return "", fmt.Errorf("failed to acquire lock %s", lockKey)
|
|
}
|
|
|
|
// ReleaseLock releases a distributed lock
|
|
func (r *SqlRepo) ReleaseLock(ctx context.Context, lockKey string, nodeID string) error {
|
|
// Only delete lock if it belongs to us AND it hasn't expired yet
|
|
// This prevents deadlocks from competing DELETE operations on expired locks
|
|
result := r.db.WithContext(ctx).
|
|
Where("lock_key = ? AND node_id = ? AND expires_at > NOW()", lockKey, nodeID).
|
|
Delete(&NodeSyncLock{})
|
|
|
|
if result.Error != nil {
|
|
// Log but don't fail - lock might already be released or expired
|
|
slog.Debug("failed to release lock", "lockKey", lockKey, "nodeID", nodeID, "error", result.Error)
|
|
return nil
|
|
}
|
|
|
|
if result.RowsAffected == 0 {
|
|
// Lock was already released or expired - not an error
|
|
slog.Debug("lock already released or expired", "lockKey", lockKey, "nodeID", nodeID)
|
|
return nil
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CleanupExpiredLocks removes all expired locks from the database
|
|
// This prevents deadlocks from piling up expired lock rows
|
|
func (r *SqlRepo) CleanupExpiredLocks(ctx context.Context) error {
|
|
result := r.db.WithContext(ctx).
|
|
Where("expires_at < ?", time.Now()).
|
|
Delete(&NodeSyncLock{})
|
|
|
|
if result.Error != nil {
|
|
slog.Warn("failed to cleanup expired locks", "error", result.Error)
|
|
return result.Error
|
|
}
|
|
|
|
if result.RowsAffected > 0 {
|
|
slog.Debug("cleaned up expired locks", "rows_deleted", result.RowsAffected)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion generic distributed locking
|