From 7d0da4e7ad79d9be4fc260132abeec3e30a55f71 Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Sun, 23 Mar 2025 23:09:47 +0100 Subject: [PATCH] chore: use interfaces for all other services --- cmd/wg-portal/main.go | 29 +-- .../app/api/v0/backend/interface_service.go | 91 ++++++++++ internal/app/api/v0/backend/peer_service.go | 112 ++++++++++++ internal/app/api/v0/backend/user_service.go | 83 +++++++++ internal/app/app.go | 92 +++++----- internal/app/audit/recorder.go | 24 ++- internal/app/audit/repos.go | 11 -- internal/app/auth/auth.go | 81 +++++++-- internal/app/auth/auth_ldap.go | 8 + internal/app/auth/auth_oauth.go | 13 +- internal/app/auth/auth_oidc.go | 12 +- internal/app/configfile/manager.go | 61 ++++++- internal/app/configfile/repos.go | 22 --- internal/app/configfile/template.go | 4 + internal/app/mail/manager.go | 51 +++++- internal/app/mail/repos.go | 28 --- internal/app/mail/template.go | 3 + internal/app/repos.go | 77 -------- internal/app/route/repos.go | 12 -- internal/app/route/routes.go | 31 +++- internal/app/users/repos.go | 20 --- internal/app/users/user_manager.go | 50 +++++- internal/app/wireguard/repos.go | 87 --------- internal/app/wireguard/statistics.go | 48 ++++- internal/app/wireguard/wireguard.go | 71 +++++++- .../app/wireguard/wireguard_interfaces.go | 13 ++ internal/app/wireguard/wireguard_peers.go | 11 ++ internal/config/auth.go | 7 + internal/config/config.go | 2 + internal/config/database.go | 3 + internal/config/mail.go | 5 + internal/config/web.go | 1 + internal/domain/auth.go | 31 ---- internal/domain/crypto.go | 4 + internal/domain/crypto_test.go | 56 ++++++ internal/domain/interface_test.go | 83 +++++++++ internal/domain/options_test.go | 42 +++++ internal/domain/peer_test.go | 165 ++++++++++++++++++ internal/domain/statistics_test.go | 74 ++++++++ internal/domain/user_test.go | 125 +++++++++++++ 40 files changed, 1337 insertions(+), 406 deletions(-) create mode 100644 internal/app/api/v0/backend/interface_service.go create mode 100644 internal/app/api/v0/backend/peer_service.go create mode 100644 internal/app/api/v0/backend/user_service.go delete mode 100644 internal/app/audit/repos.go delete mode 100644 internal/app/configfile/repos.go delete mode 100644 internal/app/mail/repos.go delete mode 100644 internal/app/repos.go delete mode 100644 internal/app/route/repos.go delete mode 100644 internal/app/users/repos.go delete mode 100644 internal/app/wireguard/repos.go create mode 100644 internal/domain/crypto_test.go create mode 100644 internal/domain/interface_test.go create mode 100644 internal/domain/options_test.go create mode 100644 internal/domain/peer_test.go create mode 100644 internal/domain/statistics_test.go create mode 100644 internal/domain/user_test.go diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index 18070e4..f388298 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -14,6 +14,7 @@ import ( "github.com/h44z/wg-portal/internal/adapters" "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/app/api/core" + backendV0 "github.com/h44z/wg-portal/internal/app/api/v0/backend" handlersV0 "github.com/h44z/wg-portal/internal/app/api/v0/handlers" backendV1 "github.com/h44z/wg-portal/internal/app/api/v1/backend" handlersV1 "github.com/h44z/wg-portal/internal/app/api/v1/handlers" @@ -70,17 +71,24 @@ func main() { queueSize := 100 eventBus := evbus.New(queueSize) + auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database) + internal.AssertNoError(err) + auditRecorder.StartBackgroundJobs(ctx) + userManager, err := users.NewUserManager(cfg, eventBus, database, database) internal.AssertNoError(err) + userManager.StartBackgroundJobs(ctx) authenticator, err := auth.NewAuthenticator(&cfg.Auth, cfg.Web.ExternalUrl, eventBus, userManager) internal.AssertNoError(err) wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database) internal.AssertNoError(err) + wireGuardManager.StartBackgroundJobs(ctx) statisticsCollector, err := wireguard.NewStatisticsCollector(cfg, eventBus, database, wireGuard, metricsServer) internal.AssertNoError(err) + statisticsCollector.StartBackgroundJobs(ctx) cfgFileManager, err := configfile.NewConfigFileManager(cfg, eventBus, database, database, cfgFileSystem) internal.AssertNoError(err) @@ -88,18 +96,11 @@ func main() { mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database) internal.AssertNoError(err) - auditRecorder, err := audit.NewAuditRecorder(cfg, eventBus, database) - internal.AssertNoError(err) - auditRecorder.StartBackgroundJobs(ctx) - routeManager, err := route.NewRouteManager(cfg, eventBus, database) internal.AssertNoError(err) routeManager.StartBackgroundJobs(ctx) - backend, err := app.New(cfg, eventBus, authenticator, userManager, wireGuardManager, - statisticsCollector, cfgFileManager, mailManager) - internal.AssertNoError(err) - err = backend.Startup(ctx) + err = app.Initialize(cfg, wireGuardManager, userManager) internal.AssertNoError(err) validatorManager := validator.New() @@ -109,10 +110,14 @@ func main() { apiV0Session := handlersV0.NewSessionWrapper(cfg) apiV0Auth := handlersV0.NewAuthenticationHandler(authenticator, apiV0Session) - apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, backend) - apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, backend) - apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, backend) - apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, backend) + apiV0BackendUsers := backendV0.NewUserService(cfg, userManager, wireGuardManager) + apiV0BackendInterfaces := backendV0.NewInterfaceService(cfg, wireGuardManager, cfgFileManager) + apiV0BackendPeers := backendV0.NewPeerService(cfg, wireGuardManager, cfgFileManager, mailManager) + + apiV0EndpointAuth := handlersV0.NewAuthEndpoint(cfg, apiV0Auth, apiV0Session, validatorManager, authenticator) + apiV0EndpointUsers := handlersV0.NewUserEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendUsers) + apiV0EndpointInterfaces := handlersV0.NewInterfaceEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendInterfaces) + apiV0EndpointPeers := handlersV0.NewPeerEndpoint(cfg, apiV0Auth, validatorManager, apiV0BackendPeers) apiV0EndpointConfig := handlersV0.NewConfigEndpoint(cfg, apiV0Auth) apiV0EndpointTest := handlersV0.NewTestEndpoint(apiV0Auth) diff --git a/internal/app/api/v0/backend/interface_service.go b/internal/app/api/v0/backend/interface_service.go new file mode 100644 index 0000000..fca1fe8 --- /dev/null +++ b/internal/app/api/v0/backend/interface_service.go @@ -0,0 +1,91 @@ +package backend + +import ( + "context" + "io" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +// region dependencies + +type InterfaceServiceInterfaceManager interface { + GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) + GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) + CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) + UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) + DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error + PrepareInterface(ctx context.Context) (*domain.Interface, error) + ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error +} + +type InterfaceServiceConfigFileManager interface { + PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error + GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) +} + +// endregion dependencies + +type InterfaceService struct { + cfg *config.Config + + interfaces InterfaceServiceInterfaceManager + configFile InterfaceServiceConfigFileManager +} + +func NewInterfaceService( + cfg *config.Config, + interfaces InterfaceServiceInterfaceManager, + configFile InterfaceServiceConfigFileManager, +) *InterfaceService { + return &InterfaceService{ + cfg: cfg, + interfaces: interfaces, + configFile: configFile, + } +} + +func (i InterfaceService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( + *domain.Interface, + []domain.Peer, + error, +) { + return i.interfaces.GetInterfaceAndPeers(ctx, id) +} + +func (i InterfaceService) PrepareInterface(ctx context.Context) (*domain.Interface, error) { + return i.interfaces.PrepareInterface(ctx) +} + +func (i InterfaceService) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) { + return i.interfaces.CreateInterface(ctx, in) +} + +func (i InterfaceService) UpdateInterface(ctx context.Context, in *domain.Interface) ( + *domain.Interface, + []domain.Peer, + error, +) { + return i.interfaces.UpdateInterface(ctx, in) +} + +func (i InterfaceService) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { + return i.interfaces.DeleteInterface(ctx, id) +} + +func (i InterfaceService) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) { + return i.interfaces.GetAllInterfacesAndPeers(ctx) +} + +func (i InterfaceService) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) { + return i.configFile.GetInterfaceConfig(ctx, id) +} + +func (i InterfaceService) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error { + return i.configFile.PersistInterfaceConfig(ctx, id) +} + +func (i InterfaceService) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error { + return i.interfaces.ApplyPeerDefaults(ctx, in) +} diff --git a/internal/app/api/v0/backend/peer_service.go b/internal/app/api/v0/backend/peer_service.go new file mode 100644 index 0000000..f590693 --- /dev/null +++ b/internal/app/api/v0/backend/peer_service.go @@ -0,0 +1,112 @@ +package backend + +import ( + "context" + "io" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +// region dependencies + +type PeerServicePeerManager interface { + GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) + GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) + GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) + PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) + CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) + UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) + DeletePeer(ctx context.Context, id domain.PeerIdentifier) error + CreateMultiplePeers( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + r *domain.PeerCreationRequest, + ) ([]domain.Peer, error) + GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) +} + +type PeerServiceConfigFileManager interface { + GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) + GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) +} + +type PeerServiceMailManager interface { + SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error +} + +// endregion dependencies + +type PeerService struct { + cfg *config.Config + + peers PeerServicePeerManager + configFile PeerServiceConfigFileManager + mailer PeerServiceMailManager +} + +func NewPeerService( + cfg *config.Config, + peers PeerServicePeerManager, + configFile PeerServiceConfigFileManager, + mailer PeerServiceMailManager, +) *PeerService { + return &PeerService{ + cfg: cfg, + peers: peers, + configFile: configFile, + mailer: mailer, + } +} + +func (p PeerService) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( + *domain.Interface, + []domain.Peer, + error, +) { + return p.peers.GetInterfaceAndPeers(ctx, id) +} + +func (p PeerService) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) { + return p.peers.PreparePeer(ctx, id) +} + +func (p PeerService) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) { + return p.peers.GetPeer(ctx, id) +} + +func (p PeerService) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { + return p.peers.CreatePeer(ctx, peer) +} + +func (p PeerService) CreateMultiplePeers( + ctx context.Context, + interfaceId domain.InterfaceIdentifier, + r *domain.PeerCreationRequest, +) ([]domain.Peer, error) { + return p.peers.CreateMultiplePeers(ctx, interfaceId, r) +} + +func (p PeerService) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { + return p.peers.UpdatePeer(ctx, peer) +} + +func (p PeerService) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { + return p.peers.DeletePeer(ctx, id) +} + +func (p PeerService) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) { + return p.configFile.GetPeerConfig(ctx, id) +} + +func (p PeerService) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) { + return p.configFile.GetPeerConfigQrCode(ctx, id) +} + +func (p PeerService) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error { + return p.mailer.SendPeerEmail(ctx, linkOnly, peers...) +} + +func (p PeerService) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) { + return p.peers.GetPeerStats(ctx, id) +} diff --git a/internal/app/api/v0/backend/user_service.go b/internal/app/api/v0/backend/user_service.go new file mode 100644 index 0000000..74fee34 --- /dev/null +++ b/internal/app/api/v0/backend/user_service.go @@ -0,0 +1,83 @@ +package backend + +import ( + "context" + + "github.com/h44z/wg-portal/internal/config" + "github.com/h44z/wg-portal/internal/domain" +) + +// region dependencies + +type UserServiceUserManager interface { + GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) + GetAllUsers(ctx context.Context) ([]domain.User, error) + CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) + UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) + DeleteUser(ctx context.Context, id domain.UserIdentifier) error + ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) + DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) +} + +type UserServiceWireGuardManager interface { + GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) + GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) ([]domain.Interface, error) + GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) +} + +// endregion dependencies + +type UserService struct { + cfg *config.Config + + users UserServiceUserManager + wg UserServiceWireGuardManager +} + +func NewUserService(cfg *config.Config, users UserServiceUserManager, wg UserServiceWireGuardManager) *UserService { + return &UserService{ + cfg: cfg, + users: users, + wg: wg, + } +} + +func (u UserService) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { + return u.users.GetUser(ctx, id) +} + +func (u UserService) GetAllUsers(ctx context.Context) ([]domain.User, error) { + return u.users.GetAllUsers(ctx) +} + +func (u UserService) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) { + return u.users.UpdateUser(ctx, user) +} + +func (u UserService) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) { + return u.users.CreateUser(ctx, user) +} + +func (u UserService) DeleteUser(ctx context.Context, id domain.UserIdentifier) error { + return u.users.DeleteUser(ctx, id) +} + +func (u UserService) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { + return u.users.ActivateApi(ctx, id) +} + +func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { + return u.users.DeactivateApi(ctx, id) +} + +func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { + return u.wg.GetUserPeers(ctx, id) +} + +func (u UserService) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) { + return u.wg.GetUserPeerStats(ctx, id) +} + +func (u UserService) GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error) { + return u.wg.GetUserInterfaces(ctx, id) +} diff --git a/internal/app/app.go b/internal/app/app.go index c9d596e..1eb24cb 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -7,46 +7,43 @@ import ( "log/slog" "time" - evbus "github.com/vardius/message-bus" - "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) -type App struct { - Config *config.Config - bus evbus.MessageBus +// region dependencies - Authenticator - UserManager - WireGuardManager - StatisticsCollector - ConfigFileManager - MailManager - ApiV1Manager +type WireGuardManager interface { + ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) + RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error } -func New( +type UserManager interface { + GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) + CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) +} + +// endregion dependencies + +// App is the main application struct. +type App struct { + cfg *config.Config + + wg WireGuardManager + users UserManager +} + +// Initialize creates a new App instance and initializes it. +func Initialize( cfg *config.Config, - bus evbus.MessageBus, - authenticator Authenticator, + wg WireGuardManager, users UserManager, - wireGuard WireGuardManager, - stats StatisticsCollector, - cfgFiles ConfigFileManager, - mailer MailManager, -) (*App, error) { - +) error { a := &App{ - Config: cfg, - bus: bus, + cfg: cfg, - Authenticator: authenticator, - UserManager: users, - WireGuardManager: wireGuard, - StatisticsCollector: stats, - ConfigFileManager: cfgFiles, - MailManager: mailer, + wg: wg, + users: users, } startupContext, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -56,36 +53,27 @@ func New( startupContext = domain.SetUserInfo(startupContext, domain.SystemAdminContextUserInfo()) if err := a.createDefaultUser(startupContext); err != nil { - return nil, fmt.Errorf("failed to create default user: %w", err) + return fmt.Errorf("failed to create default user: %w", err) } if err := a.importNewInterfaces(startupContext); err != nil { - return nil, fmt.Errorf("failed to import new interfaces: %w", err) + return fmt.Errorf("failed to import new interfaces: %w", err) } if err := a.restoreInterfaceState(startupContext); err != nil { - return nil, fmt.Errorf("failed to restore interface state: %w", err) + return fmt.Errorf("failed to restore interface state: %w", err) } - return a, nil -} - -func (a *App) Startup(ctx context.Context) error { - - a.UserManager.StartBackgroundJobs(ctx) - a.StatisticsCollector.StartBackgroundJobs(ctx) - a.WireGuardManager.StartBackgroundJobs(ctx) - return nil } func (a *App) importNewInterfaces(ctx context.Context) error { - if !a.Config.Core.ImportExisting { + if !a.cfg.Core.ImportExisting { slog.Debug("skipping interface import - feature disabled") return nil // feature disabled } - importedCount, err := a.ImportNewInterfaces(ctx) + importedCount, err := a.wg.ImportNewInterfaces(ctx) if err != nil { return err } @@ -97,12 +85,12 @@ func (a *App) importNewInterfaces(ctx context.Context) error { } func (a *App) restoreInterfaceState(ctx context.Context) error { - if !a.Config.Core.RestoreState { + if !a.cfg.Core.RestoreState { slog.Debug("skipping interface state restore - feature disabled") return nil // feature disabled } - err := a.RestoreInterfaceState(ctx, true) + err := a.wg.RestoreInterfaceState(ctx, true) if err != nil { return err } @@ -112,13 +100,13 @@ func (a *App) restoreInterfaceState(ctx context.Context) error { } func (a *App) createDefaultUser(ctx context.Context) error { - adminUserId := domain.UserIdentifier(a.Config.Core.AdminUser) + adminUserId := domain.UserIdentifier(a.cfg.Core.AdminUser) if adminUserId == "" { slog.Debug("skipping default user creation - admin user is blank") return nil // empty admin user - do not create } - _, err := a.GetUser(ctx, adminUserId) + _, err := a.users.GetUser(ctx, adminUserId) if err != nil && !errors.Is(err, domain.ErrNotFound) { return err } @@ -145,22 +133,22 @@ func (a *App) createDefaultUser(ctx context.Context) error { Phone: "", Department: "", Notes: "default administrator user", - Password: domain.PrivateString(a.Config.Core.AdminPassword), + Password: domain.PrivateString(a.cfg.Core.AdminPassword), Disabled: nil, DisabledReason: "", Locked: nil, LockedReason: "", LinkedPeerCount: 0, } - if a.Config.Core.AdminApiToken != "" { - if len(a.Config.Core.AdminApiToken) < 18 { + if a.cfg.Core.AdminApiToken != "" { + if len(a.cfg.Core.AdminApiToken) < 18 { slog.Warn("admin API token is too short, should be at least 18 characters long") } - defaultAdmin.ApiToken = a.Config.Core.AdminApiToken + defaultAdmin.ApiToken = a.cfg.Core.AdminApiToken defaultAdmin.ApiTokenCreated = &now } - admin, err := a.CreateUser(ctx, defaultAdmin) + admin, err := a.users.CreateUser(ctx, defaultAdmin) if err != nil { return err } diff --git a/internal/app/audit/recorder.go b/internal/app/audit/recorder.go index abcb4c3..79858e6 100644 --- a/internal/app/audit/recorder.go +++ b/internal/app/audit/recorder.go @@ -6,21 +6,35 @@ import ( "log/slog" "time" - evbus "github.com/vardius/message-bus" - "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) +// region dependencies + +type DatabaseRepo interface { + // SaveAuditEntry saves an audit entry to the database + SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error +} + +type EventBus interface { + // Subscribe subscribes to a topic + Subscribe(topic string, fn interface{}) error +} + +// endregion dependencies + +// Recorder is responsible for recording audit events to the database. type Recorder struct { cfg *config.Config - bus evbus.MessageBus + bus EventBus db DatabaseRepo } -func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo) (*Recorder, error) { +// NewAuditRecorder creates a new audit recorder instance. +func NewAuditRecorder(cfg *config.Config, bus EventBus, db DatabaseRepo) (*Recorder, error) { r := &Recorder{ cfg: cfg, bus: bus, @@ -36,6 +50,8 @@ func NewAuditRecorder(cfg *config.Config, bus evbus.MessageBus, db DatabaseRepo) return r, nil } +// StartBackgroundJobs starts background jobs for the audit recorder. +// This method is non-blocking and returns immediately. func (r *Recorder) StartBackgroundJobs(ctx context.Context) { if !r.cfg.Statistics.CollectAuditData { return // noting to do diff --git a/internal/app/audit/repos.go b/internal/app/audit/repos.go deleted file mode 100644 index d64b49b..0000000 --- a/internal/app/audit/repos.go +++ /dev/null @@ -1,11 +0,0 @@ -package audit - -import ( - "context" - - "github.com/h44z/wg-portal/internal/domain" -) - -type DatabaseRepo interface { - SaveAuditEntry(ctx context.Context, entry *domain.AuditEntry) error -} diff --git a/internal/app/auth/auth.go b/internal/app/auth/auth.go index d06b00f..f4d0f29 100644 --- a/internal/app/auth/auth.go +++ b/internal/app/auth/auth.go @@ -14,25 +14,78 @@ import ( "time" "github.com/coreos/go-oidc/v3/oidc" - evbus "github.com/vardius/message-bus" + "golang.org/x/oauth2" "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) +// region dependencies + type UserManager interface { + // GetUser returns a user by its identifier. GetUser(context.Context, domain.UserIdentifier) (*domain.User, error) + // RegisterUser creates a new user in the database. RegisterUser(ctx context.Context, user *domain.User) error + // UpdateUser updates an existing user in the database. UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) } +type EventBus interface { + // Publish sends a message to the message bus. + Publish(topic string, args ...any) +} + +// endregion dependencies + +type AuthenticatorType string + +const ( + AuthenticatorTypeOAuth AuthenticatorType = "oauth" + AuthenticatorTypeOidc AuthenticatorType = "oidc" +) + +// AuthenticatorOauth is the interface for all OAuth authenticators. +type AuthenticatorOauth interface { + // GetName returns the name of the authenticator. + GetName() string + // GetType returns the type of the authenticator. It can be either AuthenticatorTypeOAuth or AuthenticatorTypeOidc. + GetType() AuthenticatorType + // AuthCodeURL returns the URL for the authentication flow. + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + // Exchange exchanges the OAuth code for an access token. + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + // GetUserInfo fetches the user information from the OAuth or OIDC provider. + GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error) + // ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct. + ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) + // RegistrationEnabled returns whether registration is enabled for the OAuth authenticator. + RegistrationEnabled() bool +} + +// AuthenticatorLdap is the interface for all LDAP authenticators. +type AuthenticatorLdap interface { + // GetName returns the name of the authenticator. + GetName() string + // PlaintextAuthentication performs a plaintext authentication against the LDAP server. + PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error + // GetUserInfo fetches the user information from the LDAP server. + GetUserInfo(ctx context.Context, username domain.UserIdentifier) (map[string]any, error) + // ParseUserInfo parses the raw user information into a domain.AuthenticatorUserInfo struct. + ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) + // RegistrationEnabled returns whether registration is enabled for the LDAP authenticator. + RegistrationEnabled() bool +} + +// Authenticator is the main entry point for all authentication related tasks. +// This includes password authentication and external authentication providers (OIDC, OAuth, LDAP). type Authenticator struct { cfg *config.Auth - bus evbus.MessageBus + bus EventBus - oauthAuthenticators map[string]domain.OauthAuthenticator - ldapAuthenticators map[string]domain.LdapAuthenticator + oauthAuthenticators map[string]AuthenticatorOauth + ldapAuthenticators map[string]AuthenticatorLdap // URL prefix for the callback endpoints, this is a combination of the external URL and the API prefix callbackUrlPrefix string @@ -40,7 +93,8 @@ type Authenticator struct { users UserManager } -func NewAuthenticator(cfg *config.Auth, extUrl string, bus evbus.MessageBus, users UserManager) ( +// NewAuthenticator creates a new Authenticator instance. +func NewAuthenticator(cfg *config.Auth, extUrl string, bus EventBus, users UserManager) ( *Authenticator, error, ) { @@ -68,8 +122,8 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error { return fmt.Errorf("failed to parse external url: %w", err) } - a.oauthAuthenticators = make(map[string]domain.OauthAuthenticator, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth)) - a.ldapAuthenticators = make(map[string]domain.LdapAuthenticator, len(a.cfg.Ldap)) + a.oauthAuthenticators = make(map[string]AuthenticatorOauth, len(a.cfg.OpenIDConnect)+len(a.cfg.OAuth)) + a.ldapAuthenticators = make(map[string]AuthenticatorLdap, len(a.cfg.Ldap)) for i := range a.cfg.OpenIDConnect { // OIDC providerCfg := &a.cfg.OpenIDConnect[i] @@ -123,6 +177,7 @@ func (a *Authenticator) setupExternalAuthProviders(ctx context.Context) error { return nil } +// GetExternalLoginProviders returns a list of all available external login providers. func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo { authProviders := make([]domain.LoginProviderInfo, 0, len(a.cfg.OAuth)+len(a.cfg.OpenIDConnect)) @@ -157,6 +212,7 @@ func (a *Authenticator) GetExternalLoginProviders(_ context.Context) []domain.Lo return authProviders } +// IsUserValid checks if a user is valid and not locked or disabled. func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifier) bool { ctx = domain.SetUserInfo(ctx, domain.SystemAdminContextUserInfo()) // switch to admin user context user, err := a.users.GetUser(ctx, id) @@ -177,6 +233,8 @@ func (a *Authenticator) IsUserValid(ctx context.Context, id domain.UserIdentifie // region password authentication +// PlainLogin performs a password authentication for a user. The username and password are trimmed before usage. +// If the login is successful, the user is returned, otherwise an error. func (a *Authenticator) PlainLogin(ctx context.Context, username, password string) (*domain.User, error) { // Validate form input username = strings.TrimSpace(username) @@ -204,7 +262,7 @@ func (a *Authenticator) passwordAuthentication( domain.SystemAdminContextUserInfo()) // switch to admin user context to check if user exists var ldapUserInfo *domain.AuthenticatorUserInfo - var ldapProvider domain.LdapAuthenticator + var ldapProvider AuthenticatorLdap var userInDatabase = false var userSource domain.UserSource @@ -280,6 +338,7 @@ func (a *Authenticator) passwordAuthentication( // region oauth authentication +// OauthLoginStep1 starts the oauth authentication flow by returning the authentication URL, state and nonce. func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) ( authCodeUrl, state, nonce string, err error, @@ -296,9 +355,9 @@ func (a *Authenticator) OauthLoginStep1(_ context.Context, providerId string) ( } switch oauthProvider.GetType() { - case domain.AuthenticatorTypeOAuth: + case AuthenticatorTypeOAuth: authCodeUrl = oauthProvider.AuthCodeURL(state) - case domain.AuthenticatorTypeOidc: + case AuthenticatorTypeOidc: nonce, err = a.randString(16) if err != nil { return "", "", "", fmt.Errorf("failed to generate nonce: %w", err) @@ -318,6 +377,8 @@ func (a *Authenticator) randString(nByte int) (string, error) { return base64.RawURLEncoding.EncodeToString(b), nil } +// OauthLoginStep2 finishes the oauth authentication flow by exchanging the code for an access token and +// fetching the user information. func (a *Authenticator) OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) { oauthProvider, ok := a.oauthAuthenticators[providerId] if !ok { diff --git a/internal/app/auth/auth_ldap.go b/internal/app/auth/auth_ldap.go index 3fa2fe4..64122f2 100644 --- a/internal/app/auth/auth_ldap.go +++ b/internal/app/auth/auth_ldap.go @@ -14,6 +14,7 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +// LdapAuthenticator is an authenticator that uses LDAP for authentication. type LdapAuthenticator struct { cfg *config.LdapProvider } @@ -33,14 +34,17 @@ func newLdapAuthenticator(_ context.Context, cfg *config.LdapProvider) (*LdapAut return provider, nil } +// GetName returns the name of the LDAP authenticator. func (l LdapAuthenticator) GetName() string { return l.cfg.ProviderName } +// RegistrationEnabled returns whether registration is enabled for the LDAP authenticator. func (l LdapAuthenticator) RegistrationEnabled() bool { return l.cfg.RegistrationEnabled } +// PlaintextAuthentication performs a plaintext authentication against the LDAP server. func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, plainPassword string) error { conn, err := internal.LdapConnect(l.cfg) if err != nil { @@ -81,6 +85,9 @@ func (l LdapAuthenticator) PlaintextAuthentication(userId domain.UserIdentifier, return nil } +// GetUserInfo retrieves user information from the LDAP server. +// If the user is not found, domain.ErrNotFound is returned. +// If multiple users are found, domain.ErrNotUnique is returned. func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIdentifier) ( map[string]any, error, @@ -126,6 +133,7 @@ func (l LdapAuthenticator) GetUserInfo(_ context.Context, userId domain.UserIden return users[0], nil } +// ParseUserInfo parses the user information from the LDAP server into a domain.AuthenticatorUserInfo struct. func (l LdapAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { isAdmin, err := internal.LdapIsMemberOf(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.ParsedAdminGroupDN) if err != nil { diff --git a/internal/app/auth/auth_oauth.go b/internal/app/auth/auth_oauth.go index e2d141f..7e730bc 100644 --- a/internal/app/auth/auth_oauth.go +++ b/internal/app/auth/auth_oauth.go @@ -16,6 +16,8 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication. +// User information is retrieved from the specified user info endpoint. type PlainOauthAuthenticator struct { name string cfg *oauth2.Config @@ -58,22 +60,27 @@ func newPlainOauthAuthenticator( return provider, nil } +// GetName returns the name of the OAuth authenticator. func (p PlainOauthAuthenticator) GetName() string { return p.name } +// RegistrationEnabled returns whether registration is enabled for the OAuth authenticator. func (p PlainOauthAuthenticator) RegistrationEnabled() bool { return p.registrationEnabled } -func (p PlainOauthAuthenticator) GetType() domain.AuthenticatorType { - return domain.AuthenticatorTypeOAuth +// GetType returns the type of the authenticator. +func (p PlainOauthAuthenticator) GetType() AuthenticatorType { + return AuthenticatorTypeOAuth } +// AuthCodeURL returns the URL to redirect the user to for authentication. func (p PlainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { return p.cfg.AuthCodeURL(state, opts...) } +// Exchange exchanges the OAuth code for a token. func (p PlainOauthAuthenticator) Exchange( ctx context.Context, code string, @@ -82,6 +89,7 @@ func (p PlainOauthAuthenticator) Exchange( return p.cfg.Exchange(ctx, code, opts...) } +// GetUserInfo retrieves the user information from the user info endpoint. func (p PlainOauthAuthenticator) GetUserInfo( ctx context.Context, token *oauth2.Token, @@ -119,6 +127,7 @@ func (p PlainOauthAuthenticator) GetUserInfo( return userFields, nil } +// ParseUserInfo parses the user information from the raw data. func (p PlainOauthAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { return parseOauthUserInfo(p.userInfoMapping, p.userAdminMapping, raw) } diff --git a/internal/app/auth/auth_oidc.go b/internal/app/auth/auth_oidc.go index d213959..d832768 100644 --- a/internal/app/auth/auth_oidc.go +++ b/internal/app/auth/auth_oidc.go @@ -14,6 +14,7 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +// OidcAuthenticator is an authenticator for OpenID Connect providers. type OidcAuthenticator struct { name string provider *oidc.Provider @@ -60,22 +61,27 @@ func newOidcAuthenticator( return provider, nil } +// GetName returns the name of the authenticator. func (o OidcAuthenticator) GetName() string { return o.name } +// RegistrationEnabled returns whether registration is enabled for this authenticator. func (o OidcAuthenticator) RegistrationEnabled() bool { return o.registrationEnabled } -func (o OidcAuthenticator) GetType() domain.AuthenticatorType { - return domain.AuthenticatorTypeOidc +// GetType returns the type of the authenticator. +func (o OidcAuthenticator) GetType() AuthenticatorType { + return AuthenticatorTypeOidc } +// AuthCodeURL returns the URL for the OAuth2 flow. func (o OidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { return o.cfg.AuthCodeURL(state, opts...) } +// Exchange exchanges the code for a token. func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) ( *oauth2.Token, error, @@ -83,6 +89,7 @@ func (o OidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oa return o.cfg.Exchange(ctx, code, opts...) } +// GetUserInfo retrieves the user info from the token. func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) ( map[string]any, error, @@ -114,6 +121,7 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, return tokenFields, nil } +// ParseUserInfo parses the user info. func (o OidcAuthenticator) ParseUserInfo(raw map[string]any) (*domain.AuthenticatorUserInfo, error) { return parseOauthUserInfo(o.userInfoMapping, o.userAdminMapping, raw) } diff --git a/internal/app/configfile/manager.go b/internal/app/configfile/manager.go index 9f9a4f7..6e859b4 100644 --- a/internal/app/configfile/manager.go +++ b/internal/app/configfile/manager.go @@ -10,7 +10,6 @@ import ( "os" "strings" - evbus "github.com/vardius/message-bus" "github.com/yeqown/go-qrcode/v2" "github.com/yeqown/go-qrcode/writer/compressed" @@ -19,19 +18,56 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) -type Manager struct { - cfg *config.Config - bus evbus.MessageBus - tplHandler *TemplateHandler +// region dependencies - fsRepo FileSystemRepo - users UserDatabaseRepo - wg WireguardDatabaseRepo +type UserDatabaseRepo interface { + // GetUser returns the user with the given identifier from the SQL database. + GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) } +type WireguardDatabaseRepo interface { + // GetInterfaceAndPeers returns the interface and all peers associated with it. + GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) + // GetPeer returns the peer with the given identifier. + GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) + // GetInterface returns the interface with the given identifier. + GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) +} + +type FileSystemRepo interface { + // WriteFile writes the contents to the file at the given path. + WriteFile(path string, contents io.Reader) error +} + +type TemplateRenderer interface { + // GetInterfaceConfig returns the configuration file for the given interface. + GetInterfaceConfig(iface *domain.Interface, peers []domain.Peer) (io.Reader, error) + // GetPeerConfig returns the configuration file for the given peer. + GetPeerConfig(peer *domain.Peer) (io.Reader, error) +} + +type EventBus interface { + // Subscribe subscribes to the given topic. + Subscribe(topic string, fn any) error +} + +// endregion dependencies + +// Manager is responsible for managing the configuration files of the WireGuard interfaces and peers. +type Manager struct { + cfg *config.Config + bus EventBus + + tplHandler TemplateRenderer + fsRepo FileSystemRepo + users UserDatabaseRepo + wg WireguardDatabaseRepo +} + +// NewConfigFileManager creates a new Manager instance. func NewConfigFileManager( cfg *config.Config, - bus evbus.MessageBus, + bus EventBus, users UserDatabaseRepo, wg WireguardDatabaseRepo, fsRepo FileSystemRepo, @@ -115,6 +151,8 @@ func (m Manager) handlePeerInterfaceUpdatedEvent(id domain.InterfaceIdentifier) } } +// GetInterfaceConfig returns the configuration file for the given interface. +// The file is structured in wg-quick format. func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -128,6 +166,8 @@ func (m Manager) GetInterfaceConfig(ctx context.Context, id domain.InterfaceIden return m.tplHandler.GetInterfaceConfig(iface, peers) } +// GetPeerConfig returns the configuration file for the given peer. +// The file is structured in wg-quick format. func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) { peer, err := m.wg.GetPeer(ctx, id) if err != nil { @@ -141,6 +181,7 @@ func (m Manager) GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (i return m.tplHandler.GetPeerConfig(peer) } +// GetPeerConfigQrCode returns a QR code image containing the configuration for the given peer. func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) { peer, err := m.wg.GetPeer(ctx, id) if err != nil { @@ -191,6 +232,7 @@ func (m Manager) GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifi return buf, nil } +// PersistInterfaceConfig writes the configuration file for the given interface to the file system. func (m Manager) PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error { iface, peers, err := m.wg.GetInterfaceAndPeers(ctx, id) if err != nil { @@ -213,4 +255,5 @@ type nopCloser struct { io.Writer } +// Close is a no-op for the nopCloser. func (nopCloser) Close() error { return nil } diff --git a/internal/app/configfile/repos.go b/internal/app/configfile/repos.go deleted file mode 100644 index c7c9b30..0000000 --- a/internal/app/configfile/repos.go +++ /dev/null @@ -1,22 +0,0 @@ -package configfile - -import ( - "context" - "io" - - "github.com/h44z/wg-portal/internal/domain" -) - -type UserDatabaseRepo interface { - GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) -} - -type WireguardDatabaseRepo interface { - GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) - GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) -} - -type FileSystemRepo interface { - WriteFile(path string, contents io.Reader) error -} diff --git a/internal/app/configfile/template.go b/internal/app/configfile/template.go index 5d2bdde..ed40d9b 100644 --- a/internal/app/configfile/template.go +++ b/internal/app/configfile/template.go @@ -13,6 +13,8 @@ import ( //go:embed tpl_files/* var TemplateFiles embed.FS +// TemplateHandler is responsible for rendering the WireGuard configuration files +// based on the provided templates. type TemplateHandler struct { templates *template.Template } @@ -34,6 +36,7 @@ func newTemplateHandler() (*TemplateHandler, error) { return handler, nil } +// GetInterfaceConfig returns the rendered configuration file for a WireGuard interface. func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domain.Peer) (io.Reader, error) { var tplBuff bytes.Buffer @@ -51,6 +54,7 @@ func (c TemplateHandler) GetInterfaceConfig(cfg *domain.Interface, peers []domai return &tplBuff, nil } +// GetPeerConfig returns the rendered configuration file for a WireGuard peer. func (c TemplateHandler) GetPeerConfig(peer *domain.Peer) (io.Reader, error) { var tplBuff bytes.Buffer diff --git a/internal/app/mail/manager.go b/internal/app/mail/manager.go index fb3e28c..7b31406 100644 --- a/internal/app/mail/manager.go +++ b/internal/app/mail/manager.go @@ -10,16 +10,60 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) -type Manager struct { - cfg *config.Config - tplHandler *TemplateHandler +// region dependencies +type Mailer interface { + // Send sends an email with the given subject and body to the given recipients. + Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error +} + +type ConfigFileManager interface { + // GetInterfaceConfig returns the configuration for the given interface. + GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) + // GetPeerConfig returns the configuration for the given peer. + GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) + // GetPeerConfigQrCode returns the QR code for the given peer. + GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) +} + +type UserDatabaseRepo interface { + // GetUser returns the user with the given identifier. + GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) +} + +type WireguardDatabaseRepo interface { + // GetInterfaceAndPeers returns the interface and all peers for the given interface identifier. + GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) + // GetPeer returns the peer with the given identifier. + GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) + // GetInterface returns the interface with the given identifier. + GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) +} + +type TemplateRenderer interface { + // GetConfigMail returns the text and html template for the mail with a link. + GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) + // GetConfigMailWithAttachment returns the text and html template for the mail with an attachment. + GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) ( + io.Reader, + io.Reader, + error, + ) +} + +// endregion dependencies + +type Manager struct { + cfg *config.Config + + tplHandler TemplateRenderer mailer Mailer configFiles ConfigFileManager users UserDatabaseRepo wg WireguardDatabaseRepo } +// NewMailManager creates a new mail manager. func NewMailManager( cfg *config.Config, mailer Mailer, @@ -44,6 +88,7 @@ func NewMailManager( return m, nil } +// SendPeerEmail sends an email to the user linked to the given peers. func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error { for _, peerId := range peers { peer, err := m.wg.GetPeer(ctx, peerId) diff --git a/internal/app/mail/repos.go b/internal/app/mail/repos.go deleted file mode 100644 index 79b3ea9..0000000 --- a/internal/app/mail/repos.go +++ /dev/null @@ -1,28 +0,0 @@ -package mail - -import ( - "context" - "io" - - "github.com/h44z/wg-portal/internal/domain" -) - -type Mailer interface { - Send(ctx context.Context, subject, body string, to []string, options *domain.MailOptions) error -} - -type ConfigFileManager interface { - GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) - GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) - GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) -} - -type UserDatabaseRepo interface { - GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) -} - -type WireguardDatabaseRepo interface { - GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) - GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) -} diff --git a/internal/app/mail/template.go b/internal/app/mail/template.go index 3aea4de..15253bf 100644 --- a/internal/app/mail/template.go +++ b/internal/app/mail/template.go @@ -14,6 +14,7 @@ import ( //go:embed tpl_files/* var TemplateFiles embed.FS +// TemplateHandler is a struct that holds the html and text templates. type TemplateHandler struct { portalUrl string htmlTemplates *htmlTemplate.Template @@ -40,6 +41,7 @@ func newTemplateHandler(portalUrl string) (*TemplateHandler, error) { return handler, nil } +// GetConfigMail returns the text and html template for the mail with a link. func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reader, io.Reader, error) { var tplBuff bytes.Buffer var htmlTplBuff bytes.Buffer @@ -65,6 +67,7 @@ func (c TemplateHandler) GetConfigMail(user *domain.User, link string) (io.Reade return &tplBuff, &htmlTplBuff, nil } +// GetConfigMailWithAttachment returns the text and html template for the mail with an attachment. func (c TemplateHandler) GetConfigMailWithAttachment(user *domain.User, cfgName, qrName string) ( io.Reader, io.Reader, diff --git a/internal/app/repos.go b/internal/app/repos.go deleted file mode 100644 index 85fae2a..0000000 --- a/internal/app/repos.go +++ /dev/null @@ -1,77 +0,0 @@ -package app - -import ( - "context" - "io" - - "github.com/h44z/wg-portal/internal/domain" -) - -type Authenticator interface { - GetExternalLoginProviders(_ context.Context) []domain.LoginProviderInfo - IsUserValid(ctx context.Context, id domain.UserIdentifier) bool - PlainLogin(ctx context.Context, username, password string) (*domain.User, error) - OauthLoginStep1(_ context.Context, providerId string) (authCodeUrl, state, nonce string, err error) - OauthLoginStep2(ctx context.Context, providerId, nonce, code string) (*domain.User, error) -} - -type UserManager interface { - RegisterUser(ctx context.Context, user *domain.User) error - NewUser(ctx context.Context, user *domain.User) error - StartBackgroundJobs(ctx context.Context) - GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) - GetAllUsers(ctx context.Context) ([]domain.User, error) - UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) - CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) - DeleteUser(ctx context.Context, id domain.UserIdentifier) error - ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) - DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) -} - -type WireGuardManager interface { - StartBackgroundJobs(ctx context.Context) - GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) - ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) - RestoreInterfaceState(ctx context.Context, updateDbOnError bool, filter ...domain.InterfaceIdentifier) error - CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error - GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) - GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) - GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) - GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) - GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) - GetUserInterfaces(ctx context.Context, id domain.UserIdentifier) ([]domain.Interface, error) - PrepareInterface(ctx context.Context) (*domain.Interface, error) - CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) - UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) - DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error - PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) - GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - CreatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error) - CreateMultiplePeers( - ctx context.Context, - id domain.InterfaceIdentifier, - r *domain.PeerCreationRequest, - ) ([]domain.Peer, error) - UpdatePeer(ctx context.Context, p *domain.Peer) (*domain.Peer, error) - DeletePeer(ctx context.Context, id domain.PeerIdentifier) error - ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error -} - -type StatisticsCollector interface { - StartBackgroundJobs(ctx context.Context) -} - -type ConfigFileManager interface { - GetInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) (io.Reader, error) - GetPeerConfig(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) - GetPeerConfigQrCode(ctx context.Context, id domain.PeerIdentifier) (io.Reader, error) - PersistInterfaceConfig(ctx context.Context, id domain.InterfaceIdentifier) error -} - -type MailManager interface { - SendPeerEmail(ctx context.Context, linkOnly bool, peers ...domain.PeerIdentifier) error -} - -type ApiV1Manager interface { - ApiV1GetUsers(ctx context.Context) ([]domain.User, error) -} diff --git a/internal/app/route/repos.go b/internal/app/route/repos.go deleted file mode 100644 index e9eb533..0000000 --- a/internal/app/route/repos.go +++ /dev/null @@ -1,12 +0,0 @@ -package route - -import ( - "context" - - "github.com/h44z/wg-portal/internal/domain" -) - -type InterfaceAndPeerDatabaseRepo interface { - GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) - GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) -} diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go index 864f9a1..c87bcaf 100644 --- a/internal/app/route/routes.go +++ b/internal/app/route/routes.go @@ -5,7 +5,6 @@ import ( "fmt" "log/slog" - evbus "github.com/vardius/message-bus" "github.com/vishvananda/netlink" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/wgctrl" @@ -17,6 +16,22 @@ import ( "github.com/h44z/wg-portal/internal/lowlevel" ) +// region dependencies + +type InterfaceAndPeerDatabaseRepo interface { + // GetAllInterfaces returns all interfaces + GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) + // GetInterfacePeers returns all peers for a given interface + GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) +} + +type EventBus interface { + // Subscribe subscribes to a topic + Subscribe(topic string, fn interface{}) error +} + +// endregion dependencies + type routeRuleInfo struct { ifaceId domain.InterfaceIdentifier fwMark uint32 @@ -29,14 +44,15 @@ type routeRuleInfo struct { // for default routes. type Manager struct { cfg *config.Config - bus evbus.MessageBus - wg lowlevel.WireGuardClient - nl lowlevel.NetlinkClient - db InterfaceAndPeerDatabaseRepo + bus EventBus + wg lowlevel.WireGuardClient + nl lowlevel.NetlinkClient + db InterfaceAndPeerDatabaseRepo } -func NewRouteManager(cfg *config.Config, bus evbus.MessageBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { +// NewRouteManager creates a new route manager instance. +func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) { wg, err := wgctrl.New() if err != nil { panic("failed to init wgctrl: " + err.Error()) @@ -63,7 +79,10 @@ func (m Manager) connectToMessageBus() { _ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent) } +// StartBackgroundJobs starts background jobs for the route manager. +// This method is non-blocking and returns immediately. func (m Manager) StartBackgroundJobs(_ context.Context) { + // this is a no-op for now } func (m Manager) handleRouteUpdateEvent(srcDescription string) { diff --git a/internal/app/users/repos.go b/internal/app/users/repos.go deleted file mode 100644 index 967b429..0000000 --- a/internal/app/users/repos.go +++ /dev/null @@ -1,20 +0,0 @@ -package users - -import ( - "context" - - "github.com/h44z/wg-portal/internal/domain" -) - -type UserDatabaseRepo interface { - GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) - GetUserByEmail(ctx context.Context, email string) (*domain.User, error) - GetAllUsers(ctx context.Context) ([]domain.User, error) - FindUsers(ctx context.Context, search string) ([]domain.User, error) - SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error - DeleteUser(ctx context.Context, id domain.UserIdentifier) error -} - -type PeerDatabaseRepo interface { - GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) -} diff --git a/internal/app/users/user_manager.go b/internal/app/users/user_manager.go index 85c50cc..180eca2 100644 --- a/internal/app/users/user_manager.go +++ b/internal/app/users/user_manager.go @@ -11,7 +11,6 @@ import ( "github.com/go-ldap/ldap/v3" "github.com/google/uuid" - evbus "github.com/vardius/message-bus" "github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal/app" @@ -19,15 +18,46 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +// region dependencies + +type UserDatabaseRepo interface { + // GetUser returns the user with the given identifier. + GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) + // GetUserByEmail returns the user with the given email address. + GetUserByEmail(ctx context.Context, email string) (*domain.User, error) + // GetAllUsers returns all users. + GetAllUsers(ctx context.Context) ([]domain.User, error) + // FindUsers returns all users matching the search string. + FindUsers(ctx context.Context, search string) ([]domain.User, error) + // SaveUser saves the user with the given identifier. + SaveUser(ctx context.Context, id domain.UserIdentifier, updateFunc func(u *domain.User) (*domain.User, error)) error + // DeleteUser deletes the user with the given identifier. + DeleteUser(ctx context.Context, id domain.UserIdentifier) error +} + +type PeerDatabaseRepo interface { + // GetUserPeers returns all peers linked to the given user. + GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) +} + +type EventBus interface { + // Publish sends a message to the message bus. + Publish(topic string, args ...any) +} + +// endregion dependencies + +// Manager is the user manager. type Manager struct { cfg *config.Config - bus evbus.MessageBus + bus EventBus users UserDatabaseRepo peers PeerDatabaseRepo } -func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabaseRepo, peers PeerDatabaseRepo) ( +// NewUserManager creates a new user manager instance. +func NewUserManager(cfg *config.Config, bus EventBus, users UserDatabaseRepo, peers PeerDatabaseRepo) ( *Manager, error, ) { @@ -41,6 +71,7 @@ func NewUserManager(cfg *config.Config, bus evbus.MessageBus, users UserDatabase return m, nil } +// RegisterUser registers a new user. func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return err @@ -56,6 +87,7 @@ func (m Manager) RegisterUser(ctx context.Context, user *domain.User) error { return nil } +// NewUser creates a new user. func (m Manager) NewUser(ctx context.Context, user *domain.User) error { if user.Identifier == "" { return errors.New("missing user identifier") @@ -90,12 +122,13 @@ func (m Manager) NewUser(ctx context.Context, user *domain.User) error { return nil } +// StartBackgroundJobs starts the background jobs. +// This method is non-blocking and returns immediately. func (m Manager) StartBackgroundJobs(ctx context.Context) { - go m.runLdapSynchronizationService(ctx) - } +// GetUser returns the user with the given identifier. func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { if err := domain.ValidateUserAccessRights(ctx, id); err != nil { return nil, err @@ -112,6 +145,7 @@ func (m Manager) GetUser(ctx context.Context, id domain.UserIdentifier) (*domain return user, nil } +// GetUserByEmail returns the user with the given email address. func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) { user, err := m.users.GetUserByEmail(ctx, email) @@ -130,6 +164,7 @@ func (m Manager) GetUserByEmail(ctx context.Context, email string) (*domain.User return user, nil } +// GetAllUsers returns all users. func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -162,6 +197,7 @@ func (m Manager) GetAllUsers(ctx context.Context) ([]domain.User, error) { return users, nil } +// UpdateUser updates the user with the given identifier. func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.User, error) { if err := domain.ValidateUserAccessRights(ctx, user.Identifier); err != nil { return nil, err @@ -203,6 +239,7 @@ func (m Manager) UpdateUser(ctx context.Context, user *domain.User) (*domain.Use return user, nil } +// CreateUser creates a new user. func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.User, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -236,6 +273,7 @@ func (m Manager) CreateUser(ctx context.Context, user *domain.User) (*domain.Use return user, nil } +// DeleteUser deletes the user with the given identifier. func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return err @@ -260,6 +298,7 @@ func (m Manager) DeleteUser(ctx context.Context, id domain.UserIdentifier) error return nil } +// ActivateApi activates the API access for the user with the given identifier. func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { user, err := m.users.GetUser(ctx, id) if err != nil && !errors.Is(err, domain.ErrNotFound) { @@ -287,6 +326,7 @@ func (m Manager) ActivateApi(ctx context.Context, id domain.UserIdentifier) (*do return user, nil } +// DeactivateApi deactivates the API access for the user with the given identifier. func (m Manager) DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error) { user, err := m.users.GetUser(ctx, id) if err != nil && !errors.Is(err, domain.ErrNotFound) { diff --git a/internal/app/wireguard/repos.go b/internal/app/wireguard/repos.go deleted file mode 100644 index 67d88e1..0000000 --- a/internal/app/wireguard/repos.go +++ /dev/null @@ -1,87 +0,0 @@ -package wireguard - -import ( - "context" - - "github.com/h44z/wg-portal/internal/domain" -) - -type InterfaceAndPeerDatabaseRepo interface { - GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) - GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) - GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) - GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) - FindInterfaces(ctx context.Context, search string) ([]domain.Interface, error) - GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) - SaveInterface( - ctx context.Context, - id domain.InterfaceIdentifier, - updateFunc func(in *domain.Interface) (*domain.Interface, error), - ) error - DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error - GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) - FindInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier, search string) ([]domain.Peer, error) - GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) - FindUserPeers(ctx context.Context, id domain.UserIdentifier, search string) ([]domain.Peer, error) - SavePeer( - ctx context.Context, - id domain.PeerIdentifier, - updateFunc func(in *domain.Peer) (*domain.Peer, error), - ) error - DeletePeer(ctx context.Context, id domain.PeerIdentifier) error - GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) -} - -type StatisticsDatabaseRepo interface { - GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) - GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) - GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) - - UpdatePeerStatus( - ctx context.Context, - id domain.PeerIdentifier, - updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error), - ) error - UpdateInterfaceStatus( - ctx context.Context, - id domain.InterfaceIdentifier, - updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error), - ) error - - DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error -} - -type InterfaceController interface { - GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) - GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) - GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) - GetPeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) ( - *domain.PhysicalPeer, - error, - ) - SaveInterface( - _ context.Context, - id domain.InterfaceIdentifier, - updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), - ) error - DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error - SavePeer( - _ context.Context, - deviceId domain.InterfaceIdentifier, - id domain.PeerIdentifier, - updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), - ) error - DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error -} - -type WgQuickController interface { - ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error - SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error - UnsetDNS(id domain.InterfaceIdentifier) error -} - -type MetricsServer interface { - UpdateInterfaceMetrics(status domain.InterfaceStatus) - UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus) -} diff --git a/internal/app/wireguard/statistics.go b/internal/app/wireguard/statistics.go index ed9bf75..ffa7571 100644 --- a/internal/app/wireguard/statistics.go +++ b/internal/app/wireguard/statistics.go @@ -7,31 +7,63 @@ import ( "time" probing "github.com/prometheus-community/pro-bing" - evbus "github.com/vardius/message-bus" "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) +type StatisticsDatabaseRepo interface { + GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) + GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) + GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) + UpdatePeerStatus( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.PeerStatus) (*domain.PeerStatus, error), + ) error + UpdateInterfaceStatus( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.InterfaceStatus) (*domain.InterfaceStatus, error), + ) error + DeletePeerStatus(ctx context.Context, id domain.PeerIdentifier) error +} + +type StatisticsInterfaceController interface { + GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) + GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) +} + +type StatisticsMetricsServer interface { + UpdateInterfaceMetrics(status domain.InterfaceStatus) + UpdatePeerMetrics(peer *domain.Peer, status domain.PeerStatus) +} + +type StatisticsEventBus interface { + // Subscribe subscribes to a topic + Subscribe(topic string, fn interface{}) error +} + type StatisticsCollector struct { cfg *config.Config - bus evbus.MessageBus + bus StatisticsEventBus pingWaitGroup sync.WaitGroup pingJobs chan domain.Peer db StatisticsDatabaseRepo - wg InterfaceController - ms MetricsServer + wg StatisticsInterfaceController + ms StatisticsMetricsServer } +// NewStatisticsCollector creates a new statistics collector. func NewStatisticsCollector( cfg *config.Config, - bus evbus.MessageBus, + bus StatisticsEventBus, db StatisticsDatabaseRepo, - wg InterfaceController, - ms MetricsServer, + wg StatisticsInterfaceController, + ms StatisticsMetricsServer, ) (*StatisticsCollector, error) { c := &StatisticsCollector{ cfg: cfg, @@ -47,6 +79,8 @@ func NewStatisticsCollector( return c, nil } +// StartBackgroundJobs starts the background jobs for the statistics collector. +// This method is non-blocking and returns immediately. func (c *StatisticsCollector) StartBackgroundJobs(ctx context.Context) { c.startPingWorkers(ctx) c.startInterfaceDataFetcher(ctx) diff --git a/internal/app/wireguard/wireguard.go b/internal/app/wireguard/wireguard.go index 82fd55a..e0673e3 100644 --- a/internal/app/wireguard/wireguard.go +++ b/internal/app/wireguard/wireguard.go @@ -5,17 +5,74 @@ import ( "log/slog" "time" - evbus "github.com/vardius/message-bus" - "github.com/h44z/wg-portal/internal/app" "github.com/h44z/wg-portal/internal/config" "github.com/h44z/wg-portal/internal/domain" ) -type Manager struct { - cfg *config.Config - bus evbus.MessageBus +// region dependencies +type InterfaceAndPeerDatabaseRepo interface { + GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error) + GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, []domain.Peer, error) + GetPeersStats(ctx context.Context, ids ...domain.PeerIdentifier) ([]domain.PeerStatus, error) + GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) + GetInterfaceIps(ctx context.Context) (map[domain.InterfaceIdentifier][]domain.Cidr, error) + SaveInterface( + ctx context.Context, + id domain.InterfaceIdentifier, + updateFunc func(in *domain.Interface) (*domain.Interface, error), + ) error + DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error + GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error) + GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) + SavePeer( + ctx context.Context, + id domain.PeerIdentifier, + updateFunc func(in *domain.Peer) (*domain.Peer, error), + ) error + DeletePeer(ctx context.Context, id domain.PeerIdentifier) error + GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) + GetUsedIpsPerSubnet(ctx context.Context, subnets []domain.Cidr) (map[domain.Cidr][]domain.Cidr, error) +} + +type InterfaceController interface { + GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error) + GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error) + GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error) + SaveInterface( + _ context.Context, + id domain.InterfaceIdentifier, + updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error), + ) error + DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error + SavePeer( + _ context.Context, + deviceId domain.InterfaceIdentifier, + id domain.PeerIdentifier, + updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error), + ) error + DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error +} + +type WgQuickController interface { + ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error + SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error + UnsetDNS(id domain.InterfaceIdentifier) error +} + +type EventBus interface { + // Publish sends a message to the message bus. + Publish(topic string, args ...any) + // Subscribe subscribes to a topic + Subscribe(topic string, fn interface{}) error +} + +// endregion dependencies + +type Manager struct { + cfg *config.Config + bus EventBus db InterfaceAndPeerDatabaseRepo wg InterfaceController quick WgQuickController @@ -23,7 +80,7 @@ type Manager struct { func NewWireGuardManager( cfg *config.Config, - bus evbus.MessageBus, + bus EventBus, wg InterfaceController, quick WgQuickController, db InterfaceAndPeerDatabaseRepo, @@ -41,6 +98,8 @@ func NewWireGuardManager( return m, nil } +// StartBackgroundJobs starts background jobs like the expired peers check. +// This method is non-blocking. func (m Manager) StartBackgroundJobs(ctx context.Context) { go m.runExpiredPeersCheck(ctx) } diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index ae6f0aa..703ecf9 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -13,6 +13,8 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +// GetImportableInterfaces returns all physical interfaces that are available on the system. +// This function also returns interfaces that are already available in the database. func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.PhysicalInterface, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -26,6 +28,7 @@ func (m Manager) GetImportableInterfaces(ctx context.Context) ([]domain.Physical return physicalInterfaces, nil } +// GetInterfaceAndPeers returns the interface and all peers for the given interface identifier. func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceIdentifier) ( *domain.Interface, []domain.Peer, @@ -38,6 +41,7 @@ func (m Manager) GetInterfaceAndPeers(ctx context.Context, id domain.InterfaceId return m.db.GetInterfaceAndPeers(ctx, id) } +// GetAllInterfaces returns all interfaces that are available in the database. func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -46,6 +50,7 @@ func (m Manager) GetAllInterfaces(ctx context.Context) ([]domain.Interface, erro return m.db.GetAllInterfaces(ctx) } +// GetAllInterfacesAndPeers returns all interfaces and their peers. func (m Manager) GetAllInterfacesAndPeers(ctx context.Context) ([]domain.Interface, [][]domain.Peer, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, nil, err @@ -97,6 +102,7 @@ func (m Manager) GetUserInterfaces(ctx context.Context, _ domain.UserIdentifier) return userInterfaces, nil } +// ImportNewInterfaces imports all new physical interfaces that are available on the system. func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.InterfaceIdentifier) (int, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return 0, err @@ -148,6 +154,7 @@ func (m Manager) ImportNewInterfaces(ctx context.Context, filter ...domain.Inter return imported, nil } +// ApplyPeerDefaults applies the interface defaults to all peers of the given interface. func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) error { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return err @@ -179,6 +186,8 @@ func (m Manager) ApplyPeerDefaults(ctx context.Context, in *domain.Interface) er return nil } +// RestoreInterfaceState restores the state of all physical interfaces and their peers. +// The final state of the interfaces and peers will be the same as stored in the database. func (m Manager) RestoreInterfaceState( ctx context.Context, updateDbOnError bool, @@ -296,6 +305,7 @@ func (m Manager) RestoreInterfaceState( return nil } +// PrepareInterface generates a new interface with fresh keys, ip addresses and a listen port. func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -376,6 +386,7 @@ func (m Manager) PrepareInterface(ctx context.Context) (*domain.Interface, error return freshInterface, nil } +// CreateInterface creates a new interface with the given configuration. func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, err @@ -401,6 +412,7 @@ func (m Manager) CreateInterface(ctx context.Context, in *domain.Interface) (*do return in, nil } +// UpdateInterface updates the given interface with the new configuration. func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*domain.Interface, []domain.Peer, error) { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return nil, nil, err @@ -423,6 +435,7 @@ func (m Manager) UpdateInterface(ctx context.Context, in *domain.Interface) (*do return in, existingPeers, nil } +// DeleteInterface deletes the given interface. func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentifier) error { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return err diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index 8418d85..8f3c9bf 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -11,6 +11,7 @@ import ( "github.com/h44z/wg-portal/internal/domain" ) +// CreateDefaultPeer creates a default peer for the given user on all server interfaces. func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdentifier) error { if err := domain.ValidateAdminAccessRights(ctx); err != nil { return err @@ -55,6 +56,7 @@ func (m Manager) CreateDefaultPeer(ctx context.Context, userId domain.UserIdenti return nil } +// GetUserPeers returns all peers for the given user. func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) { if err := domain.ValidateUserAccessRights(ctx, id); err != nil { return nil, err @@ -63,6 +65,7 @@ func (m Manager) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([] return m.db.GetUserPeers(ctx, id) } +// PreparePeer prepares a new peer for the given interface with fresh keys and ip addresses. func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Peer, error) { if !m.cfg.Core.SelfProvisioningAllowed { if err := domain.ValidateAdminAccessRights(ctx); err != nil { @@ -143,6 +146,7 @@ func (m Manager) PreparePeer(ctx context.Context, id domain.InterfaceIdentifier) return freshPeer, nil } +// GetPeer returns the peer with the given identifier. func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain.Peer, error) { peer, err := m.db.GetPeer(ctx, id) if err != nil { @@ -156,6 +160,7 @@ func (m Manager) GetPeer(ctx context.Context, id domain.PeerIdentifier) (*domain return peer, nil } +// CreatePeer creates a new peer. func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { if !m.cfg.Core.SelfProvisioningAllowed { if err := domain.ValidateAdminAccessRights(ctx); err != nil { @@ -201,6 +206,8 @@ func (m Manager) CreatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee return peer, nil } +// CreateMultiplePeers creates multiple new peers for the given user identifiers. +// It calls PreparePeer for each user identifier in the request. func (m Manager) CreateMultiplePeers( ctx context.Context, interfaceId domain.InterfaceIdentifier, @@ -243,6 +250,7 @@ func (m Manager) CreateMultiplePeers( return createdPeers, nil } +// UpdatePeer updates the given peer. func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Peer, error) { existingPeer, err := m.db.GetPeer(ctx, peer.Identifier) if err != nil { @@ -309,6 +317,7 @@ func (m Manager) UpdatePeer(ctx context.Context, peer *domain.Peer) (*domain.Pee return peer, nil } +// DeletePeer deletes the peer with the given identifier. func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error { peer, err := m.db.GetPeer(ctx, id) if err != nil { @@ -341,6 +350,7 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error return nil } +// GetPeerStats returns the status of the peer with the given identifier. func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.PeerStatus, error) { _, peers, err := m.db.GetInterfaceAndPeers(ctx, id) if err != nil { @@ -359,6 +369,7 @@ func (m Manager) GetPeerStats(ctx context.Context, id domain.InterfaceIdentifier return m.db.GetPeersStats(ctx, peerIds...) } +// GetUserPeerStats returns the status of all peers for the given user. func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) ([]domain.PeerStatus, error) { if err := domain.ValidateUserAccessRights(ctx, id); err != nil { return nil, err diff --git a/internal/config/auth.go b/internal/config/auth.go index d115aeb..7e70dae 100644 --- a/internal/config/auth.go +++ b/internal/config/auth.go @@ -8,6 +8,7 @@ import ( "github.com/go-ldap/ldap/v3" ) +// Auth contains all authentication providers. type Auth struct { // OpenIDConnect contains a list of OpenID Connect providers. OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"` @@ -17,6 +18,7 @@ type Auth struct { Ldap []LdapProvider `yaml:"ldap"` } +// BaseFields contains the basic fields that are used to map user information from the authentication providers. type BaseFields struct { // UserIdentifier is the name of the field that contains the user identifier. UserIdentifier string `yaml:"user_identifier"` @@ -32,6 +34,7 @@ type BaseFields struct { Department string `yaml:"department"` } +// OauthFields contains extra fields that are used to map user information from OAuth providers. type OauthFields struct { BaseFields `yaml:",inline"` // IsAdmin is the name of the field that contains the admin flag. @@ -107,12 +110,14 @@ func (o *OauthAdminMapping) GetAdminGroupRegex() *regexp.Regexp { return o.adminGroupRegex } +// LdapFields contains extra fields that are used to map user information from LDAP providers. type LdapFields struct { BaseFields `yaml:",inline"` // GroupMembership is the name of the LDAP field that contains the groups to which the user belongs. GroupMembership string `yaml:"memberof"` } +// LdapProvider contains the configuration for the LDAP connection. type LdapProvider struct { // ProviderName is an internal name that is used to distinguish LDAP servers. It must not contain spaces or special characters. ProviderName string `yaml:"provider_name"` @@ -163,6 +168,7 @@ type LdapProvider struct { LogUserInfo bool `yaml:"log_user_info"` } +// OpenIDConnectProvider contains the configuration for the OpenID Connect provider. type OpenIDConnectProvider struct { // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. ProviderName string `yaml:"provider_name"` @@ -196,6 +202,7 @@ type OpenIDConnectProvider struct { LogUserInfo bool `yaml:"log_user_info"` } +// OAuthProvider contains the configuration for the OAuth provider. type OAuthProvider struct { // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. ProviderName string `yaml:"provider_name"` diff --git a/internal/config/config.go b/internal/config/config.go index 5a756ad..0bda0c8 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,6 +10,7 @@ import ( "gopkg.in/yaml.v3" ) +// Config is the main configuration struct. type Config struct { Core struct { // AdminUser defines the default administrator account that will be created @@ -179,6 +180,7 @@ func GetConfig() (*Config, error) { return cfg, nil } +// loadConfigFile loads the configuration from a YAML file into the given cfg struct. func loadConfigFile(cfg any, filename string) error { data, err := envsubst.ReadFile(filename) if err != nil { diff --git a/internal/config/database.go b/internal/config/database.go index c04fdd6..a336683 100644 --- a/internal/config/database.go +++ b/internal/config/database.go @@ -2,6 +2,8 @@ package config import "time" +// SupportedDatabase is a type for the supported database types. +// Supported: mysql, mssql, postgres, sqlite type SupportedDatabase string const ( @@ -11,6 +13,7 @@ const ( DatabaseSQLite SupportedDatabase = "sqlite" ) +// DatabaseConfig contains the configuration for the database connection. type DatabaseConfig struct { // Debug enables logging of all database statements Debug bool `yaml:"debug"` diff --git a/internal/config/mail.go b/internal/config/mail.go index ff9184a..9cc2b5b 100644 --- a/internal/config/mail.go +++ b/internal/config/mail.go @@ -1,5 +1,7 @@ package config +// MailEncryption is the type of the SMTP encryption. +// Supported: none, tls, starttls type MailEncryption string const ( @@ -8,6 +10,8 @@ const ( MailEncryptionStartTLS MailEncryption = "starttls" ) +// MailAuthType is the type of the SMTP authentication. +// Supported: plain, login, crammd5 type MailAuthType string const ( @@ -16,6 +20,7 @@ const ( MailAuthCramMD5 MailAuthType = "crammd5" ) +// MailConfig contains the configuration for the mail server which is used to send emails. type MailConfig struct { // Host is the hostname or IP of the SMTP server Host string `yaml:"host"` diff --git a/internal/config/web.go b/internal/config/web.go index 5327d6c..1743305 100644 --- a/internal/config/web.go +++ b/internal/config/web.go @@ -1,5 +1,6 @@ package config +// WebConfig contains the configuration for the web server. type WebConfig struct { // RequestLogging enables logging of all HTTP requests. RequestLogging bool `yaml:"request_logging"` diff --git a/internal/domain/auth.go b/internal/domain/auth.go index f501ec7..817d66d 100644 --- a/internal/domain/auth.go +++ b/internal/domain/auth.go @@ -1,11 +1,5 @@ package domain -import ( - "context" - - "golang.org/x/oauth2" -) - type LoginProvider string type LoginProviderInfo struct { @@ -24,28 +18,3 @@ type AuthenticatorUserInfo struct { Department string IsAdmin bool } - -type AuthenticatorType string - -const ( - AuthenticatorTypeOAuth AuthenticatorType = "oauth" - AuthenticatorTypeOidc AuthenticatorType = "oidc" -) - -type OauthAuthenticator interface { - GetName() string - GetType() AuthenticatorType - AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string - Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) - GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]any, error) - ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error) - RegistrationEnabled() bool -} - -type LdapAuthenticator interface { - GetName() string - PlaintextAuthentication(userId UserIdentifier, plainPassword string) error - GetUserInfo(ctx context.Context, username UserIdentifier) (map[string]any, error) - ParseUserInfo(raw map[string]any) (*AuthenticatorUserInfo, error) - RegistrationEnabled() bool -} diff --git a/internal/domain/crypto.go b/internal/domain/crypto.go index 165a107..3f7f303 100644 --- a/internal/domain/crypto.go +++ b/internal/domain/crypto.go @@ -33,6 +33,7 @@ func (p KeyPair) GetPublicKey() wgtypes.Key { type PreSharedKey string +// NewFreshKeypair generates a new key pair. func NewFreshKeypair() (KeyPair, error) { privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -45,6 +46,7 @@ func NewFreshKeypair() (KeyPair, error) { }, nil } +// NewPreSharedKey generates a new pre-shared key. func NewPreSharedKey() (PreSharedKey, error) { preSharedKey, err := wgtypes.GenerateKey() if err != nil { @@ -54,6 +56,8 @@ func NewPreSharedKey() (PreSharedKey, error) { return PreSharedKey(preSharedKey.String()), nil } +// PublicKeyFromPrivateKey returns the public key for a given private key. +// If the private key is invalid, an empty string is returned. func PublicKeyFromPrivateKey(key string) string { privKey, err := wgtypes.ParseKey(key) if err != nil { diff --git a/internal/domain/crypto_test.go b/internal/domain/crypto_test.go new file mode 100644 index 0000000..c0f5857 --- /dev/null +++ b/internal/domain/crypto_test.go @@ -0,0 +1,56 @@ +package domain + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func TestKeyPair_GetPrivateKeyBytesReturnsCorrectBytes(t *testing.T) { + keyPair := KeyPair{PrivateKey: base64.StdEncoding.EncodeToString([]byte("privateKey"))} + expected := []byte("privateKey") + assert.Equal(t, expected, keyPair.GetPrivateKeyBytes()) +} + +func TestKeyPair_GetPublicKeyBytesReturnsCorrectBytes(t *testing.T) { + keyPair := KeyPair{PublicKey: base64.StdEncoding.EncodeToString([]byte("publicKey"))} + expected := []byte("publicKey") + assert.Equal(t, expected, keyPair.GetPublicKeyBytes()) +} + +func TestKeyPair_GetPrivateKeyReturnsCorrectKey(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + keyPair := KeyPair{PrivateKey: privateKey.String()} + assert.Equal(t, privateKey, keyPair.GetPrivateKey()) +} + +func TestKeyPair_GetPublicKeyReturnsCorrectKey(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + keyPair := KeyPair{PublicKey: privateKey.PublicKey().String()} + assert.Equal(t, privateKey.PublicKey(), keyPair.GetPublicKey()) +} + +func TestNewFreshKeypairGeneratesValidKeypair(t *testing.T) { + keyPair, err := NewFreshKeypair() + assert.NoError(t, err) + assert.NotEmpty(t, keyPair.PrivateKey) + assert.NotEmpty(t, keyPair.PublicKey) +} + +func TestNewPreSharedKeyGeneratesValidKey(t *testing.T) { + preSharedKey, err := NewPreSharedKey() + assert.NoError(t, err) + assert.NotEmpty(t, preSharedKey) +} + +func TestPublicKeyFromPrivateKeyReturnsCorrectPublicKey(t *testing.T) { + privateKey, _ := wgtypes.GeneratePrivateKey() + expected := privateKey.PublicKey().String() + assert.Equal(t, expected, PublicKeyFromPrivateKey(privateKey.String())) +} + +func TestPublicKeyFromPrivateKeyReturnsEmptyStringOnInvalidKey(t *testing.T) { + assert.Equal(t, "", PublicKeyFromPrivateKey("invalidKey")) +} diff --git a/internal/domain/interface_test.go b/internal/domain/interface_test.go new file mode 100644 index 0000000..54aa74d --- /dev/null +++ b/internal/domain/interface_test.go @@ -0,0 +1,83 @@ +package domain + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) { + iface := &Interface{} + assert.False(t, iface.IsDisabled()) + + now := time.Now() + iface.Disabled = &now + assert.True(t, iface.IsDisabled()) +} + +func TestInterface_AddressStrReturnsCorrectString(t *testing.T) { + iface := &Interface{ + Addresses: []Cidr{ + {Cidr: "192.168.1.1/24", Addr: "192.168.1.1", NetLength: 24}, + {Cidr: "10.0.0.1/24", Addr: "10.0.0.1", NetLength: 24}, + }, + } + expected := "192.168.1.1/24,10.0.0.1/24" + assert.Equal(t, expected, iface.AddressStr()) +} + +func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) { + iface := &Interface{Identifier: "wg0"} + expected := "wg0.conf" + assert.Equal(t, expected, iface.GetConfigFileName()) + + iface.Identifier = "wg0@123" + expected = "wg0123.conf" + assert.Equal(t, expected, iface.GetConfigFileName()) +} + +func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) { + peer1 := Peer{ + Interface: PeerInterfaceConfig{ + Addresses: []Cidr{ + {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, + }, + }, + } + peer2 := Peer{ + Interface: PeerInterfaceConfig{ + Addresses: []Cidr{ + {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, + }, + }, + } + iface := &Interface{} + expected := []Cidr{ + {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32}, + {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32}, + } + assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2})) +} + +func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) { + iface := &Interface{RoutingTable: "off"} + assert.False(t, iface.ManageRoutingTable()) + + iface.RoutingTable = "100" + assert.True(t, iface.ManageRoutingTable()) +} + +func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) { + iface := &Interface{RoutingTable: ""} + assert.Equal(t, 0, iface.GetRoutingTable()) + + iface.RoutingTable = "off" + assert.Equal(t, -1, iface.GetRoutingTable()) + + iface.RoutingTable = "0x64" + assert.Equal(t, 100, iface.GetRoutingTable()) + + iface.RoutingTable = "200" + assert.Equal(t, 200, iface.GetRoutingTable()) +} diff --git a/internal/domain/options_test.go b/internal/domain/options_test.go new file mode 100644 index 0000000..2c4ab4c --- /dev/null +++ b/internal/domain/options_test.go @@ -0,0 +1,42 @@ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConfigOption_GetValueReturnsCorrectValue(t *testing.T) { + option := ConfigOption[int]{Value: 42} + assert.Equal(t, 42, option.GetValue()) +} + +func TestConfigOption_SetValueUpdatesValue(t *testing.T) { + option := ConfigOption[int]{Value: 42} + option.SetValue(100) + assert.Equal(t, 100, option.GetValue()) +} + +func TestConfigOption_TrySetValueUpdatesValueWhenOverridable(t *testing.T) { + option := ConfigOption[int]{Value: 42, Overridable: true} + result := option.TrySetValue(100) + assert.True(t, result) + assert.Equal(t, 100, option.GetValue()) +} + +func TestConfigOption_TrySetValueDoesNotUpdateValueWhenNotOverridable(t *testing.T) { + option := ConfigOption[int]{Value: 42, Overridable: false} + result := option.TrySetValue(100) + assert.False(t, result) + assert.Equal(t, 42, option.GetValue()) +} + +func TestNewConfigOptionCreatesCorrectOption(t *testing.T) { + option := NewConfigOption(42, true) + assert.Equal(t, 42, option.GetValue()) + assert.True(t, option.Overridable) + + option2 := NewConfigOption("str", false) + assert.Equal(t, "str", option2.GetValue()) + assert.False(t, option2.Overridable) +} diff --git a/internal/domain/peer_test.go b/internal/domain/peer_test.go new file mode 100644 index 0000000..856f29b --- /dev/null +++ b/internal/domain/peer_test.go @@ -0,0 +1,165 @@ +package domain + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestPeer_IsDisabled(t *testing.T) { + peer := &Peer{} + assert.False(t, peer.IsDisabled()) + + now := time.Now() + peer.Disabled = &now + assert.True(t, peer.IsDisabled()) +} + +func TestPeer_IsExpired(t *testing.T) { + peer := &Peer{} + assert.False(t, peer.IsExpired()) + + expiredTime := time.Now().Add(-time.Hour) + peer.ExpiresAt = &expiredTime + assert.True(t, peer.IsExpired()) + + futureTime := time.Now().Add(time.Hour) + peer.ExpiresAt = &futureTime + assert.False(t, peer.IsExpired()) +} + +func TestPeer_CheckAliveAddress(t *testing.T) { + peer := &Peer{} + assert.Equal(t, "", peer.CheckAliveAddress()) + + peer.Interface.CheckAliveAddress = "192.168.1.1" + assert.Equal(t, "192.168.1.1", peer.CheckAliveAddress()) + + peer.Interface.CheckAliveAddress = "" + peer.Interface.Addresses = []Cidr{{Addr: "10.0.0.1"}} + assert.Equal(t, "10.0.0.1", peer.CheckAliveAddress()) +} + +func TestPeer_GetConfigFileName(t *testing.T) { + peer := &Peer{DisplayName: "Test Peer"} + expected := "Test_Peer.conf" + assert.Equal(t, expected, peer.GetConfigFileName()) + + peer.DisplayName = "" + peer.Identifier = "12345678" + expected = "wg_12345678.conf" + assert.Equal(t, expected, peer.GetConfigFileName()) +} + +func TestPeer_ApplyInterfaceDefaults(t *testing.T) { + peer := &Peer{ + Endpoint: ConfigOption[string]{ + Value: "", + Overridable: true, + }, + EndpointPublicKey: ConfigOption[string]{ + Value: "", + Overridable: true, + }, + AllowedIPsStr: ConfigOption[string]{ + Value: "1.1.1.1/32", + Overridable: false, + }, + } + iface := &Interface{ + PeerDefEndpoint: "192.168.1.1", + KeyPair: KeyPair{ + PublicKey: "publicKey", + }, + PeerDefAllowedIPsStr: "8.8.8.8/32", + } + + peer.ApplyInterfaceDefaults(iface) + assert.Equal(t, "192.168.1.1", peer.Endpoint.GetValue()) + assert.Equal(t, "publicKey", peer.EndpointPublicKey.GetValue()) + assert.Equal(t, "1.1.1.1/32", peer.AllowedIPsStr.GetValue()) +} + +func TestPeer_GenerateDisplayName(t *testing.T) { + peer := &Peer{Identifier: "12345678"} + peer.GenerateDisplayName("Prefix") + expected := "Prefix Peer 12345678" + assert.Equal(t, expected, peer.DisplayName) + + peer.GenerateDisplayName("") + expected = "Peer 12345678" + assert.Equal(t, expected, peer.DisplayName) +} + +func TestPeer_OverwriteUserEditableFields(t *testing.T) { + peer := &Peer{} + userPeer := &Peer{ + DisplayName: "New DisplayName", + } + + peer.OverwriteUserEditableFields(userPeer) + assert.Equal(t, "New DisplayName", peer.DisplayName) +} + +func TestPeer_GetPresharedKey(t *testing.T) { + physicalPeer := PhysicalPeer{} + assert.Nil(t, physicalPeer.GetPresharedKey()) + + physicalPeer.PresharedKey = "Q0evIJTOjhyy2o5J7whvrsvQC+FRL8A74vrw44YHUAk=" + key := physicalPeer.GetPresharedKey() + assert.NotNil(t, key) +} + +func TestPeer_GetEndpointAddress(t *testing.T) { + physicalPeer := PhysicalPeer{} + assert.Nil(t, physicalPeer.GetEndpointAddress()) + + physicalPeer.Endpoint = "192.168.1.1:51820" + addr := physicalPeer.GetEndpointAddress() + assert.NotNil(t, addr) + assert.Equal(t, "192.168.1.1:51820", addr.String()) +} + +func TestPeer_GetPersistentKeepaliveTime(t *testing.T) { + physicalPeer := PhysicalPeer{} + assert.Nil(t, physicalPeer.GetPersistentKeepaliveTime()) + + physicalPeer.PersistentKeepalive = 25 + duration := physicalPeer.GetPersistentKeepaliveTime() + assert.NotNil(t, duration) + assert.Equal(t, 25*time.Second, *duration) +} + +func TestPeer_GetAllowedIPs(t *testing.T) { + physicalPeer := PhysicalPeer{} + assert.Empty(t, physicalPeer.GetAllowedIPs()) + + physicalPeer.AllowedIPs = []Cidr{ + { + Cidr: "192.168.1.0/24", + Addr: "192.168.1.0", + NetLength: 24, + }, + } + ips := physicalPeer.GetAllowedIPs() + assert.Len(t, ips, 1) + assert.Equal(t, "192.168.1.0/24", ips[0].String()) + + physicalPeer.AllowedIPs = []Cidr{ + { + Cidr: "192.168.1.0/24", + Addr: "192.168.1.0", + NetLength: 24, + }, + { + Cidr: "fe80::/64", + Addr: "fe80::", + NetLength: 64, + }, + } + ips2 := physicalPeer.GetAllowedIPs() + assert.Len(t, ips2, 2) + assert.Equal(t, "192.168.1.0/24", ips2[0].String()) + assert.Equal(t, "fe80::/64", ips2[1].String()) +} diff --git a/internal/domain/statistics_test.go b/internal/domain/statistics_test.go new file mode 100644 index 0000000..74e4678 --- /dev/null +++ b/internal/domain/statistics_test.go @@ -0,0 +1,74 @@ +package domain + +import ( + "testing" + "time" +) + +func TestPeerStatus_IsConnected(t *testing.T) { + now := time.Now() + past := now.Add(-3 * time.Minute) + recent := now.Add(-1 * time.Minute) + + tests := []struct { + name string + status PeerStatus + want bool + }{ + { + name: "Pingable and recent handshake", + status: PeerStatus{ + IsPingable: true, + LastHandshake: &recent, + }, + want: true, + }, + { + name: "Not pingable but recent handshake", + status: PeerStatus{ + IsPingable: false, + LastHandshake: &recent, + }, + want: true, + }, + { + name: "Pingable but old handshake", + status: PeerStatus{ + IsPingable: true, + LastHandshake: &past, + }, + want: true, + }, + { + name: "Not pingable and old handshake", + status: PeerStatus{ + IsPingable: false, + LastHandshake: &past, + }, + want: false, + }, + { + name: "Pingable and no handshake", + status: PeerStatus{ + IsPingable: true, + LastHandshake: nil, + }, + want: true, + }, + { + name: "Not pingable and no handshake", + status: PeerStatus{ + IsPingable: false, + LastHandshake: nil, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.status.IsConnected(); got != tt.want { + t.Errorf("IsConnected() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/domain/user_test.go b/internal/domain/user_test.go new file mode 100644 index 0000000..e786f39 --- /dev/null +++ b/internal/domain/user_test.go @@ -0,0 +1,125 @@ +package domain + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" +) + +func TestUser_IsDisabled(t *testing.T) { + user := &User{} + assert.False(t, user.IsDisabled()) + + now := time.Now() + user.Disabled = &now + assert.True(t, user.IsDisabled()) +} + +func TestUser_IsLocked(t *testing.T) { + user := &User{} + assert.False(t, user.IsLocked()) + + now := time.Now() + user.Locked = &now + assert.True(t, user.IsLocked()) +} + +func TestUser_IsApiEnabled(t *testing.T) { + user := &User{} + assert.False(t, user.IsApiEnabled()) + + user.ApiToken = "token" + assert.True(t, user.IsApiEnabled()) +} + +func TestUser_CanChangePassword(t *testing.T) { + user := &User{Source: UserSourceDatabase} + assert.NoError(t, user.CanChangePassword()) + + user.Source = UserSourceLdap + assert.Error(t, user.CanChangePassword()) + + user.Source = UserSourceOauth + assert.Error(t, user.CanChangePassword()) +} + +func TestUser_EditAllowed(t *testing.T) { + user := &User{Source: UserSourceDatabase} + newUser := &User{Source: UserSourceDatabase} + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Notes = "notes can be changed" + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Disabled = &time.Time{} + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Lastname = "lastname or other fields can be changed" + assert.NoError(t, user.EditAllowed(newUser)) + + user.Source = UserSourceLdap + newUser.Source = UserSourceLdap + newUser.Disabled = nil + newUser.Lastname = "" + newUser.Notes = "notes can be changed" + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Disabled = &time.Time{} + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Lastname = "lastname or other fields can not be changed" + assert.Error(t, user.EditAllowed(newUser)) + + user.Source = UserSourceOauth + newUser.Source = UserSourceOauth + newUser.Disabled = nil + newUser.Lastname = "" + newUser.Notes = "notes can be changed" + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Disabled = &time.Time{} + assert.NoError(t, user.EditAllowed(newUser)) + + newUser.Lastname = "lastname or other fields can not be changed" + assert.Error(t, user.EditAllowed(newUser)) +} + +func TestUser_DeleteAllowed(t *testing.T) { + user := &User{} + assert.NoError(t, user.DeleteAllowed()) +} + +func TestUser_CheckPassword(t *testing.T) { + password := "password" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + + user := &User{Source: UserSourceDatabase, Password: PrivateString(hashedPassword)} + assert.NoError(t, user.CheckPassword(password)) + + user.Password = "" + assert.Error(t, user.CheckPassword(password)) + + user.Source = UserSourceLdap + assert.Error(t, user.CheckPassword(password)) +} + +func TestUser_CheckApiToken(t *testing.T) { + user := &User{} + assert.Error(t, user.CheckApiToken("token")) + + user.ApiToken = "token" + assert.NoError(t, user.CheckApiToken("token")) + + assert.Error(t, user.CheckApiToken("wrong_token")) +} + +func TestUser_HashPassword(t *testing.T) { + user := &User{Password: "password"} + assert.NoError(t, user.HashPassword()) + assert.NotEmpty(t, user.Password) + + user.Password = "" + assert.NoError(t, user.HashPassword()) +}