Compare commits

...

14 Commits

Author SHA1 Message Date
dependabot[bot] da76327569
chore(deps): bump golang.org/x/oauth2 from 0.31.0 to 0.32.0 (#544)
Bumps [golang.org/x/oauth2](https://github.com/golang/oauth2) from 0.31.0 to 0.32.0.
- [Commits](https://github.com/golang/oauth2/compare/v0.31.0...v0.32.0)

---
updated-dependencies:
- dependency-name: golang.org/x/oauth2
  dependency-version: 0.32.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-15 21:14:22 +02:00
dependabot[bot] c154cb3977
chore(deps): bump golang.org/x/sys from 0.36.0 to 0.37.0 (#545)
Bumps [golang.org/x/sys](https://github.com/golang/sys) from 0.36.0 to 0.37.0.
- [Commits](https://github.com/golang/sys/compare/v0.36.0...v0.37.0)

---
updated-dependencies:
- dependency-name: golang.org/x/sys
  dependency-version: 0.37.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-15 21:12:15 +02:00
dependabot[bot] 7bca35728d
chore(deps): bump golang.org/x/crypto from 0.42.0 to 0.43.0 (#546)
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.42.0 to 0.43.0.
- [Commits](https://github.com/golang/crypto/compare/v0.42.0...v0.43.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-version: 0.43.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-10-15 21:12:04 +02:00
h44z 3d923b328e
password change UI (#543) (#548) 2025-10-15 21:11:40 +02:00
Christoph Haas 139fb17f98
redo UI screenshots, fix the responsiveness of the image slider for wgportal.org 2025-10-12 15:48:08 +02:00
Christoph Haas faf1d995a8
fix parsing IP addresses in UI (ip-address lib was updated to V10) 2025-10-12 15:20:38 +02:00
Christoph Haas f53d0b3d7f
add the possibility to debug oauth or oidc login issues (#541) 2025-10-12 15:09:40 +02:00
h44z cdf3a49801
Cleanup route handling (#542)
* mikrotik: allow to set DNS, wip: handle routes in wg-controller

* replace old route handling for local controller

* cleanup route handling for local backend

* implement route handling for mikrotik controller
2025-10-12 14:31:19 +02:00
Christoph Haas 298c9405f6
add support for sending emails to peers without linked user accounts if their user-identifier is a valid email address 2025-10-12 14:31:01 +02:00
Christoph c7724b620a smaller UI improvements; add system theme 2025-10-12 00:48:35 +02:00
Christoph Haas 4d19f1d8bb
fix a small visual glitch in dialogs 2025-10-06 21:10:39 +02:00
Christoph Haas 3f539a1615
improve interface view: show correct DNS entries 2025-10-06 19:26:51 +02:00
PROLICHT GmbH 0305911467
Merge pull request #539 from h44z/dependabot/go_modules/github.com/go-playground/validator/v10-10.28.0
chore(deps): bump github.com/go-playground/validator/v10 from 10.27.0 to 10.28.0
2025-10-06 16:05:21 +02:00
dependabot[bot] 85f7a5a9a6
chore(deps): bump github.com/go-playground/validator/v10
Bumps [github.com/go-playground/validator/v10](https://github.com/go-playground/validator) from 10.27.0 to 10.28.0.
- [Release notes](https://github.com/go-playground/validator/releases)
- [Commits](https://github.com/go-playground/validator/compare/v10.27.0...v10.28.0)

---
updated-dependencies:
- dependency-name: github.com/go-playground/validator/v10
  dependency-version: 10.28.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-10-06 13:48:44 +00:00
44 changed files with 1788 additions and 1073 deletions

View File

@ -53,8 +53,6 @@ func main() {
wireGuard, err := wireguard.NewControllerManager(cfg)
internal.AssertNoError(err)
wgQuick := adapters.NewWgQuickRepo()
mailer := adapters.NewSmtpMailRepo(cfg.Mail)
metricsServer := adapters.NewMetricsServer(cfg)
@ -93,7 +91,7 @@ func main() {
webAuthn, err := auth.NewWebAuthnAuthenticator(cfg, eventBus, userManager)
internal.AssertNoError(err)
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, wgQuick, database)
wireGuardManager, err := wireguard.NewWireGuardManager(cfg, eventBus, wireGuard, database)
internal.AssertNoError(err)
wireGuardManager.StartBackgroundJobs(ctx)
@ -107,7 +105,7 @@ func main() {
mailManager, err := mail.NewMailManager(cfg, mailer, cfgFileManager, database, database)
internal.AssertNoError(err)
routeManager, err := route.NewRouteManager(cfg, eventBus, database)
routeManager, err := route.NewRouteManager(cfg, eventBus, database, wireGuard)
internal.AssertNoError(err)
routeManager.StartBackgroundJobs(ctx)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 130 KiB

After

Width:  |  Height:  |  Size: 131 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 129 KiB

After

Width:  |  Height:  |  Size: 131 KiB

View File

@ -15,6 +15,9 @@ backend:
# default backend decides where new interfaces are created
default: mikrotik
# A prefix for resolvconf. Usually it is "tun.". If you are using systemd, the prefix should be empty.
local_resolvconf_prefix: "tun."
mikrotik:
- id: mikrotik # unique id, not "local"
display_name: RouterOS RB5009 # optional nice name

View File

@ -28,6 +28,7 @@ core:
backend:
default: local
local_resolvconf_prefix: tun.
advanced:
log_level: info
@ -72,6 +73,7 @@ mail:
auth_type: plain
from: Wireguard Portal <noreply@wireguard.local>
link_only: false
allow_peer_email: false
auth:
oidc: []
@ -184,6 +186,11 @@ The current MikroTik backend is in **BETA** and may not support all features.
- **Description:** The default backend to use for managing WireGuard interfaces.
Valid options are: `local`, or other backend id's configured in the `mikrotik` section.
### `local_resolvconf_prefix`
- **Default:** `tun.`
- **Description:** Interface name prefix for WireGuard interfaces on the local system which is used to configure DNS servers with *resolvconf*.
It depends on the *resolvconf* implementation you are using, most use a prefix of `tun.`, but some have an empty prefix (e.g., systemd).
### `ignored_local_interfaces`
- **Default:** *(empty)*
- **Description:** A list of interface names to exclude when enumerating local interfaces.
@ -380,7 +387,9 @@ Controls how WireGuard Portal collects and reports usage statistics, including p
## Mail
Options for configuring email notifications or sending peer configurations via email.
Options for configuring email notifications or sending peer configurations via email.
By default, emails will only be sent to peers that have a valid user record linked.
To send emails to all peers that have a valid email-address as user-identifier, set `allow_peer_email` to `true`.
### `host`
- **Default:** `127.0.0.1`
@ -418,6 +427,12 @@ Options for configuring email notifications or sending peer configurations via e
- **Default:** `false`
- **Description:** If `true`, emails only contain a link to WireGuard Portal, rather than attaching the full configuration.
### `allow_peer_email`
- **Default:** `false`
- **Description:** If `true`, and a peer has no valid user record linked, but the user-identifier of the peer is a valid email address, emails will be sent to that email address.
If false, and the peer has no valid user record linked, emails will not be sent.
If a peer has linked a valid user, the email address is always taken from the user record.
---
## Auth
@ -497,13 +512,18 @@ Below are the properties for each OIDC provider entry inside `auth.oidc`:
- `admin_group_regex`: A regular expression to match the `user_groups` claim. Each entry in the `user_groups` claim is checked against this regex.
#### `registration_enabled`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, a new user will be created in WireGuard Portal if not already present.
#### `log_user_info`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, OIDC user data is logged at the trace level upon login (for debugging).
#### `log_sensitive_info`
- **Default:** `false`
- **Description:** If `true`, sensitive OIDC user data, such as tokens and raw responses, will be logged at the trace level upon login (for debugging).
- **Important:** Keep this setting disabled in production environments! Remove logs once you finished debugging authentication issues.
---
### OAuth
@ -570,13 +590,18 @@ Below are the properties for each OAuth provider entry inside `auth.oauth`:
- `admin_group_regex`: A regular expression to match the `user_groups` claim. Each entry in the `user_groups` claim is checked against this regex.
#### `registration_enabled`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, new users are created automatically on successful login.
#### `log_user_info`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, logs user info at the trace level upon login.
#### `log_sensitive_info`
- **Default:** `false`
- **Description:** If `true`, sensitive OIDC user data, such as tokens and raw responses, will be logged at the trace level upon login (for debugging).
- **Important:** Keep this setting disabled in production environments! Remove logs once you finished debugging authentication issues.
---
### LDAP
@ -593,11 +618,11 @@ Below are the properties for each LDAP provider entry inside `auth.ldap`:
- **Description:** The LDAP server URL (e.g., `ldap://srv-ad01.company.local:389`).
#### `start_tls`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, use STARTTLS to secure the LDAP connection.
#### `cert_validation`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, validate the LDAP servers TLS certificate.
#### `tls_certificate_path`
@ -667,19 +692,19 @@ Below are the properties for each LDAP provider entry inside `auth.ldap`:
```
#### `disable_missing`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, any user **not** found in LDAP (during sync) is disabled in WireGuard Portal.
#### `auto_re_enable`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, users that where disabled because they were missing (see `disable_missing`) will be re-enabled once they are found again.
#### `registration_enabled`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, new user accounts are created in WireGuard Portal upon first login.
#### `log_user_info`
- **Default:** *(empty)*
- **Default:** `false`
- **Description:** If `true`, logs LDAP user data at the trace level upon login.
---

View File

@ -68,7 +68,7 @@
}
.tx-hero__image {
max-width: 1000px;
min-width: 600px;
min-width: 0;
width: 100%;
height: auto;
margin: 0 auto;
@ -218,7 +218,7 @@
.secondary-section .g .section .component-wrapper .responsive-grid .card {
position: relative;
background-color: #fff none repeat scroll 0% 0%;
background-color: #fff;
padding: 1.5rem;
display: flex;
flex-direction: row;
@ -363,7 +363,6 @@
<h1>A beautiful and simple UI to manage your WireGuard peers and interfaces</h1>
<p>WireGuard Portal is an open source web-based user interface that makes it easy to setup and manage
WireGuard VPN connections. It's built on top of WireGuard's official <span class="em">wgctrl</span> library.</p>
</p>
<a
href="documentation/overview/"
title="Get Started"

View File

@ -1,6 +1,6 @@
<script setup>
import { RouterLink, RouterView } from 'vue-router';
import { computed, getCurrentInstance, onMounted, ref } from "vue";
import {computed, getCurrentInstance, nextTick, onMounted, ref} from "vue";
import { authStore } from "./stores/auth";
import { securityStore } from "./stores/security";
import { settingsStore } from "@/stores/settings";
@ -11,12 +11,13 @@ const auth = authStore()
const sec = securityStore()
const settings = settingsStore()
const currentTheme = ref("auto")
onMounted(async () => {
console.log("Starting WireGuard Portal frontend...");
// restore theme from localStorage
const theme = localStorage.getItem('wgTheme') || 'light';
document.documentElement.setAttribute('data-bs-theme', theme);
switchTheme(getTheme());
await sec.LoadSecurityProperties();
await auth.LoadProviders();
@ -44,10 +45,22 @@ const switchLanguage = function (lang) {
}
}
const getTheme = function () {
return localStorage.getItem('wgTheme') || 'auto';
}
const switchTheme = function (theme) {
if (document.documentElement.getAttribute('data-bs-theme') !== theme) {
let bsTheme = theme;
if (theme === 'auto') {
bsTheme = window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
}
currentTheme.value = theme;
if (document.documentElement.getAttribute('data-bs-theme') !== bsTheme) {
console.log("Switching theme to " + theme + " (" + bsTheme + ")");
localStorage.setItem('wgTheme', theme);
document.documentElement.setAttribute('data-bs-theme', theme);
document.documentElement.setAttribute('data-bs-theme', bsTheme);
}
}
@ -137,20 +150,25 @@ const userDisplayName = computed(() => {
<div v-if="!auth.IsAuthenticated" class="nav-item">
<RouterLink :to="{ name: 'login' }" class="nav-link"><i class="fas fa-sign-in-alt fa-sm fa-fw me-2"></i>{{ $t('menu.login') }}</RouterLink>
</div>
<div class="nav-item dropdown" data-bs-theme="light">
<div class="nav-item dropdown" :key="currentTheme">
<a class="nav-link dropdown-toggle d-flex align-items-center" href="#" id="theme-menu" aria-expanded="false" data-bs-toggle="dropdown" data-bs-display="static" aria-label="Toggle theme">
<i class="fa-solid fa-circle-half-stroke"></i>
<span class="d-lg-none ms-2">Toggle theme</span>
</a>
<ul class="dropdown-menu dropdown-menu-end">
<li>
<button type="button" class="dropdown-item d-flex align-items-center" @click.prevent="switchTheme('auto')" aria-pressed="false">
<i class="fa-solid fa-circle-half-stroke"></i><span class="ms-2">System</span><i class="fa-solid fa-check ms-5" :class="{invisible:currentTheme!=='auto'}"></i>
</button>
</li>
<li>
<button type="button" class="dropdown-item d-flex align-items-center" @click.prevent="switchTheme('light')" aria-pressed="false">
<i class="fa-solid fa-sun"></i><span class="ms-2">Light</span>
<i class="fa-solid fa-sun"></i><span class="ms-2">Light</span><i class="fa-solid fa-check ms-5" :class="{invisible:currentTheme!=='light'}"></i>
</button>
</li>
<li>
<button type="button" class="dropdown-item d-flex align-items-center" @click.prevent="switchTheme('dark')" aria-pressed="true">
<i class="fa-solid fa-moon"></i><span class="ms-2">Dark</span>
<i class="fa-solid fa-moon"></i><span class="ms-2">Dark</span><i class="fa-solid fa-check ms-5" :class="{invisible:currentTheme!=='dark'}"></i>
</button>
</li>
</ul>
@ -221,4 +239,8 @@ const userDisplayName = computed(() => {
background-color: rgba(var(--bs-dark-rgb), var(--bs-bg-opacity)) !important;
color: var(--bs-badge-color)!important;
}
[data-bs-theme=dark] .navbar-dark, .navbar {
background-color: #000 !important;
}
</style>

View File

@ -104,4 +104,8 @@ a.disabled {
.vue-tags-input .ti-deletion-mark:after {
transform: scaleX(1);
}
.modal-dialog {
box-shadow: none;
}

View File

@ -316,6 +316,16 @@ async function del() {
isDeleting.value = true
try {
await interfaces.DeleteInterface(selectedInterface.value.Identifier)
// reload all interfaces and peers
await interfaces.LoadInterfaces()
if (interfaces.Count > 0 && interfaces.GetSelected !== undefined) {
const selectedInterface = interfaces.GetSelected
await peers.LoadPeers(selectedInterface.Identifier)
await peers.LoadStats(selectedInterface.Identifier)
} else {
await peers.Reset() // reset peers if no interfaces are available
}
close()
} catch (e) {
console.log(e)
@ -434,6 +444,11 @@ async function del() {
<label class="form-label mt-4">{{ $t('modals.interface-edit.firewall-mark.label') }}</label>
<input v-model="formData.FirewallMark" class="form-control" :placeholder="$t('modals.interface-edit.firewall-mark.placeholder')" type="number">
</div>
<div class="form-group col-md-6" v-if="formData.Backend!=='local'">
<label class="form-label mt-4">{{ $t('modals.interface-edit.routing-table.label') }}</label>
<input v-model="formData.RoutingTable" aria-describedby="routingTableHelp" class="form-control" :placeholder="$t('modals.interface-edit.routing-table.placeholder')" type="text">
<small id="routingTableHelp" class="form-text text-muted">{{ $t('modals.interface-edit.routing-table.description') }}</small>
</div>
<div class="form-group col-md-6" v-else>
</div>
</div>
@ -447,7 +462,7 @@ async function del() {
</div>
</div>
</fieldset>
<fieldset>
<fieldset v-if="formData.Backend==='local'">
<legend class="mt-4">{{ $t('modals.interface-edit.header-hooks') }}</legend>
<div class="form-group">
<label class="form-label mt-4">{{ $t('modals.interface-edit.pre-up.label') }}</label>
@ -472,7 +487,7 @@ async function del() {
<input v-model="formData.Disabled" class="form-check-input" type="checkbox">
<label class="form-check-label">{{ $t('modals.interface-edit.disabled.label') }}</label>
</div>
<div class="form-check form-switch">
<div class="form-check form-switch" v-if="formData.Backend==='local'">
<input v-model="formData.SaveConfig" checked="" class="form-check-input" type="checkbox">
<label class="form-check-label">{{ $t('modals.interface-edit.save-config.label') }}</label>
</div>

View File

@ -4,12 +4,12 @@ export function ipToBigInt(ip) {
// Check if it's an IPv4 address
if (ip.includes(".")) {
const addr = new Address4(ip)
return addr.bigInteger()
return addr.bigInt()
}
// Otherwise, assume it's an IPv6 address
const addr = new Address6(ip)
return addr.bigInteger()
return addr.bigInt()
}
export function humanFileSize(size) {

View File

@ -117,6 +117,7 @@
"dns": "DNS-Server",
"mtu": "MTU",
"default-keep-alive": "Standard Keepalive-Intervall",
"default-dns": "Standard DNS-Server",
"button-show-config": "Konfiguration anzeigen",
"button-download-config": "Konfiguration herunterladen",
"button-store-config": "Konfiguration für wg-quick speichern",
@ -220,6 +221,16 @@
"button-delete-text": "Passkey löschen. Sie können sich anschließend nicht mehr mit diesem Passkey anmelden.",
"button-register-title": "Passkey registrieren",
"button-register-text": "Einen neuen Passkey registrieren, um Ihr Konto zu sichern."
},
"password": {
"headline": "Passwort-Einstellungen",
"abstract": "Hier können Sie Ihr Passwort ändern.",
"current-label": "Aktuelles Passwort",
"new-label": "Neues Passwort",
"new-confirm-label": "Neues Passwort bestätigen",
"change-button-text": "Passwort ändern",
"invalid-confirm-label": "Passwörter stimmen nicht überein",
"weak-label": "Passwort ist zu schwach"
}
},
"audit": {

View File

@ -117,6 +117,7 @@
"dns": "DNS Servers",
"mtu": "MTU",
"default-keep-alive": "Default Keepalive Interval",
"default-dns": "Default DNS Servers",
"button-show-config": "Show configuration",
"button-download-config": "Download configuration",
"button-store-config": "Store configuration for wg-quick",
@ -220,6 +221,16 @@
"button-delete-text": "Delete the passkey. You will not be able to log in with this passkey anymore.",
"button-register-title": "Register Passkey",
"button-register-text": "Register a new Passkey to secure your account."
},
"password": {
"headline": "Password Settings",
"abstract": "Here you can change your password.",
"current-label": "Current Password",
"new-label": "New Password",
"new-confirm-label": "Confirm New Password",
"change-button-text": "Change Password",
"invalid-confirm-label": "Passwords do not match",
"weak-label": "Password is too weak"
}
},
"audit": {

View File

@ -115,6 +115,7 @@ export const interfaceStore = defineStore('interfaces', {
return apiWrapper.post(`${baseUrl}/new`, formData)
.then(iface => {
this.interfaces.push(iface)
this.selected = iface.Identifier
this.fetching = false
})
.catch(error => {

View File

@ -126,9 +126,14 @@ export const peerStore = defineStore('peers', {
if (!statsResponse) {
this.stats = {}
this.statsEnabled = false
} else {
this.stats = statsResponse.Stats
this.statsEnabled = statsResponse.Enabled
}
this.stats = statsResponse.Stats
this.statsEnabled = statsResponse.Enabled
},
async Reset() {
this.setPeers([])
this.setStats(undefined)
},
async PreparePeer(interfaceId) {
return apiWrapper.get(`${baseUrl}/iface/${base64_url_encode(interfaceId)}/prepare`)
@ -186,10 +191,10 @@ export const peerStore = defineStore('peers', {
async LoadStats(interfaceId) {
// if no interfaceId is given, use the currently selected interface
if (!interfaceId) {
interfaceId = interfaceStore().GetSelected.Identifier
if (!interfaceId) {
return // no interface, nothing to load
if (!interfaceStore().GetSelected || !interfaceStore().GetSelected.Identifier) {
return // no interface, nothing to load
}
interfaceId = interfaceStore().GetSelected.Identifier
}
this.fetching = true
@ -260,10 +265,10 @@ export const peerStore = defineStore('peers', {
async LoadPeers(interfaceId) {
// if no interfaceId is given, use the currently selected interface
if (!interfaceId) {
interfaceId = interfaceStore().GetSelected.Identifier
if (!interfaceId) {
if (!interfaceStore().GetSelected || !interfaceStore().GetSelected.Identifier) {
return // no interface, nothing to load
}
interfaceId = interfaceStore().GetSelected.Identifier
}
this.fetching = true

View File

@ -151,6 +151,17 @@ export const profileStore = defineStore('profile', {
})
})
},
async changePassword(formData) {
this.fetching = true
let currentUser = authStore().user.Identifier
return apiWrapper.post(`${baseUrl}/${base64_url_encode(currentUser)}/change-password`, formData)
.then(this.fetching = false)
.catch(error => {
this.fetching = false;
console.log("Failed to change password for ", currentUser, ": ", error);
throw new Error(error);
});
},
async LoadPeers() {
this.fetching = true
let currentUser = authStore().user.Identifier

View File

@ -217,14 +217,14 @@ onMounted(async () => {
<td>{{ $t('interfaces.interface.ip') }}:</td>
<td><span class="badge bg-light me-1" v-for="addr in interfaces.GetSelected.Addresses" :key="addr">{{addr}}</span></td>
</tr>
<tr>
<td>{{ $t('interfaces.interface.dns') }}:</td>
<td><span class="badge bg-light me-1" v-for="addr in interfaces.GetSelected.Dns" :key="addr">{{addr}}</span></td>
</tr>
<tr>
<td>{{ $t('interfaces.interface.mtu') }}:</td>
<td>{{interfaces.GetSelected.Mtu}}</td>
</tr>
<tr>
<td>{{ $t('interfaces.interface.default-dns') }}:</td>
<td><span class="badge bg-light me-1" v-for="addr in interfaces.GetSelected.PeerDefDns" :key="addr">{{addr}}</span></td>
</tr>
<tr>
<td>{{ $t('interfaces.interface.default-keep-alive') }}:</td>
<td>{{interfaces.GetSelected.PeerDefPersistentKeepalive}}</td>

View File

@ -1,8 +1,9 @@
<script setup>
import {onMounted, ref} from "vue";
import {computed, onMounted, ref} from "vue";
import { profileStore } from "@/stores/profile";
import { settingsStore } from "@/stores/settings";
import { authStore } from "../stores/auth";
import {notify} from "@kyvg/vue3-notification";
const profile = profileStore()
const settings = settingsStore()
@ -34,6 +35,45 @@ async function saveRename(credential) {
console.error("Failed to rename credential:", error);
}
}
const pwFormData = ref({
OldPassword: '',
Password: '',
PasswordRepeat: '',
})
const passwordWeak = computed(() => {
return pwFormData.value.Password && pwFormData.value.Password.length > 0 && pwFormData.value.Password.length < settings.Setting('MinPasswordLength')
})
const passwordChangeAllowed = computed(() => {
return pwFormData.value.Password && pwFormData.value.Password.length >= settings.Setting('MinPasswordLength') &&
pwFormData.value.Password === pwFormData.value.PasswordRepeat &&
pwFormData.value.OldPassword && pwFormData.value.OldPassword.length > 0 && pwFormData.value.OldPassword !== pwFormData.value.Password;
})
const updatePassword = async () => {
try {
await profile.changePassword(pwFormData.value);
pwFormData.value.OldPassword = '';
pwFormData.value.Password = '';
pwFormData.value.PasswordRepeat = '';
notify({
title: "Password changed!",
text: "Your password has been changed successfully.",
type: 'success',
});
} catch (e) {
notify({
title: "Failed to update password!",
text: e.toString(),
type: 'error',
})
}
}
</script>
<template>
@ -43,52 +83,45 @@ async function saveRename(credential) {
<p class="lead">{{ $t('settings.abstract') }}</p>
<div v-if="auth.IsAdmin || !settings.Setting('ApiAdminOnly')">
<div class="card border-secondary p-5" v-if="profile.user.ApiToken">
<h2 class="display-7">{{ $t('settings.api.headline') }}</h2>
<p class="lead">{{ $t('settings.api.abstract') }}</p>
<hr class="my-4">
<p>{{ $t('settings.api.active-description') }}</p>
<div class="row">
<div class="col-6">
<div class="form-group">
<label class="form-label mt-4">{{ $t('settings.api.user-label') }}</label>
<input v-model="profile.user.Identifier" class="form-control" :placeholder="$t('settings.api.user-placeholder')" type="text" readonly>
</div>
</div>
<div class="col-6">
<div class="form-group">
<label class="form-label mt-4">{{ $t('settings.api.token-label') }}</label>
<input v-model="profile.user.ApiToken" class="form-control" :placeholder="$t('settings.api.token-placeholder')" type="text" readonly>
</div>
<div class="card border-secondary p-5 mt-5" v-if="profile.user.Source === 'db'">
<h2 class="display-7">{{ $t('settings.password.headline') }}</h2>
<p class="lead">{{ $t('settings.password.abstract') }}</p>
<hr class="my-4">
<div class="row">
<div class="col-6">
<div class="form-group">
<label class="form-label mt-4" for="oldpw">{{ $t('settings.password.current-label') }}</label>
<input id="oldpw" v-model="pwFormData.OldPassword" class="form-control" :class="{ 'is-invalid': pwFormData.Password && !pwFormData.OldPassword }" type="password">
</div>
</div>
<div class="row">
<div class="col-12">
<div class="form-group">
<p class="form-label mt-4">{{ $t('settings.api.token-created-label') }} {{profile.user.ApiTokenCreated}}</p>
</div>
<div class="col-6">
</div>
</div>
<div class="row">
<div class="col-6">
<div class="form-group has-success">
<label class="form-label mt-4" for="newpw">{{ $t('settings.password.new-label') }}</label>
<input id="newpw" v-model="pwFormData.Password" class="form-control" :class="{ 'is-invalid': passwordWeak, 'is-valid': pwFormData.Password !== '' && !passwordWeak }" type="password">
<div class="invalid-feedback" v-if="passwordWeak">{{ $t('settings.password.weak-label') }}</div>
</div>
</div>
<div class="row mt-5">
<div class="col-6">
<button class="btn btn-primary" :title="$t('settings.api.button-disable-title')" @click.prevent="profile.disableApi()" :disabled="profile.isFetching">
<i class="fa-solid fa-minus-circle"></i> {{ $t('settings.api.button-disable-text') }}
</button>
</div>
<div class="col-6">
<a href="/api/v1/doc.html" target="_blank" :alt="$t('settings.api.api-link')">{{ $t('settings.api.api-link') }}</a>
<div class="col-6">
<div class="form-group">
<label class="form-label mt-4" for="confirmnewpw">{{ $t('settings.password.new-confirm-label') }}</label>
<input id="confirmnewpw" v-model="pwFormData.PasswordRepeat" class="form-control" :class="{ 'is-invalid': pwFormData.PasswordRepeat !== ''&& pwFormData.Password !== pwFormData.PasswordRepeat, 'is-valid': pwFormData.PasswordRepeat !== '' && pwFormData.Password === pwFormData.PasswordRepeat && !passwordWeak }" type="password">
<div class="invalid-feedback" v-if="pwFormData.PasswordRepeat !== ''&& pwFormData.Password !== pwFormData.PasswordRepeat">{{ $t('settings.password.invalid-confirm-label') }}</div>
</div>
</div>
</div>
<div class="card border-secondary p-5" v-else>
<h2 class="display-7">{{ $t('settings.api.headline') }}</h2>
<p class="lead">{{ $t('settings.api.abstract') }}</p>
<hr class="my-4">
<p>{{ $t('settings.api.inactive-description') }}</p>
<button class="btn btn-primary" :title="$t('settings.api.button-enable-title')" @click.prevent="profile.enableApi()" :disabled="profile.isFetching">
<i class="fa-solid fa-plus-circle"></i> {{ $t('settings.api.button-enable-text') }}
</button>
<div class="row mt-5">
<div class="col-6">
<button class="btn btn-primary" :title="$t('settings.api.button-disable-title')" @click.prevent="updatePassword" :disabled="profile.isFetching || !passwordChangeAllowed">
<i class="fa-solid fa-floppy-disk"></i> {{ $t('settings.password.change-button-text') }}
</button>
</div>
<div class="col-6">
</div>
</div>
</div>
@ -173,4 +206,53 @@ async function saveRename(credential) {
</div>
</div>
<div class="mt-5" v-if="auth.IsAdmin || !settings.Setting('ApiAdminOnly')">
<div class="card border-secondary p-5" v-if="profile.user.ApiToken">
<h2 class="display-7">{{ $t('settings.api.headline') }}</h2>
<p class="lead">{{ $t('settings.api.abstract') }}</p>
<hr class="my-4">
<p>{{ $t('settings.api.active-description') }}</p>
<div class="row">
<div class="col-6">
<div class="form-group">
<label class="form-label mt-4">{{ $t('settings.api.user-label') }}</label>
<input v-model="profile.user.Identifier" class="form-control" :placeholder="$t('settings.api.user-placeholder')" type="text" readonly>
</div>
</div>
<div class="col-6">
<div class="form-group">
<label class="form-label mt-4">{{ $t('settings.api.token-label') }}</label>
<input v-model="profile.user.ApiToken" class="form-control" :placeholder="$t('settings.api.token-placeholder')" type="text" readonly>
</div>
</div>
</div>
<div class="row">
<div class="col-12">
<div class="form-group">
<p class="form-label mt-4">{{ $t('settings.api.token-created-label') }} {{profile.user.ApiTokenCreated}}</p>
</div>
</div>
</div>
<div class="row mt-5">
<div class="col-6">
<button class="btn btn-primary" :title="$t('settings.api.button-disable-title')" @click.prevent="profile.disableApi()" :disabled="profile.isFetching">
<i class="fa-solid fa-minus-circle"></i> {{ $t('settings.api.button-disable-text') }}
</button>
</div>
<div class="col-6">
<a href="/api/v1/doc.html" target="_blank" :alt="$t('settings.api.api-link')">{{ $t('settings.api.api-link') }}</a>
</div>
</div>
</div>
<div class="card border-secondary p-5" v-else>
<h2 class="display-7">{{ $t('settings.api.headline') }}</h2>
<p class="lead">{{ $t('settings.api.abstract') }}</p>
<hr class="my-4">
<p>{{ $t('settings.api.inactive-description') }}</p>
<button class="btn btn-primary" :title="$t('settings.api.button-enable-title')" @click.prevent="profile.enableApi()" :disabled="profile.isFetching">
<i class="fa-solid fa-plus-circle"></i> {{ $t('settings.api.button-enable-text') }}
</button>
</div>
</div>
</template>

12
go.mod
View File

@ -9,7 +9,7 @@ require (
github.com/glebarez/sqlite v1.11.0
github.com/go-ldap/ldap/v3 v3.4.12
github.com/go-pkgz/routegroup v1.5.3
github.com/go-playground/validator/v10 v10.27.0
github.com/go-playground/validator/v10 v10.28.0
github.com/go-webauthn/webauthn v0.14.0
github.com/google/uuid v1.6.0
github.com/prometheus-community/pro-bing v0.7.0
@ -21,9 +21,9 @@ require (
github.com/xhit/go-simple-mail/v2 v2.16.0
github.com/yeqown/go-qrcode/v2 v2.2.5
github.com/yeqown/go-qrcode/writer/compressed v1.0.1
golang.org/x/crypto v0.42.0
golang.org/x/oauth2 v0.31.0
golang.org/x/sys v0.36.0
golang.org/x/crypto v0.43.0
golang.org/x/oauth2 v0.32.0
golang.org/x/sys v0.37.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.6.0
@ -93,9 +93,9 @@ require (
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9 // indirect
golang.org/x/mod v0.28.0 // indirect
golang.org/x/net v0.44.0 // indirect
golang.org/x/net v0.45.0 // indirect
golang.org/x/sync v0.17.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/tools v0.37.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb // indirect
google.golang.org/protobuf v1.36.10 // indirect

24
go.sum
View File

@ -93,8 +93,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-playground/validator/v10 v10.28.0 h1:Q7ibns33JjyW48gHkuFT91qX48KG0ktULL6FgHdG688=
github.com/go-playground/validator/v10 v10.28.0/go.mod h1:GoI6I1SjPBh9p7ykNE/yj3fFYbyDOpwMn5KXd+m2hUU=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
@ -262,8 +262,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9 h1:TQwNpfvNkxAVlItJf6Cr5JTsVZoC/Sj7K3OZv2Pc14A=
golang.org/x/exp v0.0.0-20251002181428-27f1f14c8bb9/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
@ -291,10 +291,10 @@ golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I=
golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.31.0 h1:8Fq0yVZLh4j4YA47vHKFTa9Ew5XIrCP8LC6UeNZnLxo=
golang.org/x/oauth2 v0.31.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY=
golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -324,8 +324,8 @@ golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -354,8 +354,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=

View File

@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"os/exec"
"slices"
"strings"
"time"
@ -84,8 +85,8 @@ func NewLocalController(cfg *config.Config) (*LocalController, error) {
wg: wg,
nl: nl,
shellCmd: "bash", // we only support bash at the moment
resolvConfIfacePrefix: "tun.", // WireGuard interfaces have a tun. prefix in resolvconf
shellCmd: "bash", // we only support bash at the moment
resolvConfIfacePrefix: cfg.Backend.LocalResolvconfPrefix, // WireGuard interfaces have a tun. prefix in resolvconf
}
return repo, nil
@ -546,7 +547,11 @@ func (c LocalController) deletePeer(deviceId domain.InterfaceIdentifier, id doma
// region wg-quick-related
func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
func (c LocalController) ExecuteInterfaceHook(
_ context.Context,
id domain.InterfaceIdentifier,
hookCmd string,
) error {
if hookCmd == "" {
return nil
}
@ -560,7 +565,7 @@ func (c LocalController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hoo
return nil
}
func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
func (c LocalController) SetDNS(_ context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" {
return nil
}
@ -589,7 +594,7 @@ func (c LocalController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearch
return nil
}
func (c LocalController) UnsetDNS(id domain.InterfaceIdentifier) error {
func (c LocalController) UnsetDNS(_ context.Context, id domain.InterfaceIdentifier, _, _ string) error {
dnsCommand := "resolvconf -d %resPref%i -f"
err := c.exec(dnsCommand, id)
@ -611,7 +616,7 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
if len(stdin) > 0 {
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
if _, err := fmt.Fprint(b, ln+"\n"); err != nil {
return err
}
}
@ -619,6 +624,8 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
}
out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil {
slog.Warn("failed to executed shell command",
"command", commandWithInterfaceName, "stdin", stdin, "output", string(out), "error", err)
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
}
slog.Debug("executed shell command",
@ -631,49 +638,116 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
// region routing-related
func (c LocalController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
// update fwmark rules
if err := c.setFwMarkRules(rules); err != nil {
return err
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
func (c LocalController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error {
interfaceId := info.Interface.Identifier
slog.Debug("setting linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
"cidrs", info.AllowedIps)
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
}
// update main rule
if err := c.setMainRule(rules); err != nil {
return err
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, info.Table, info.FwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
wgDev, err := c.wg.Device(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to get wg device for %s: %w", interfaceId, err)
}
currentFwMark := wgDev.FirewallMark
if int(realFwMark) != currentFwMark {
slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", currentFwMark,
"newFwMark", realFwMark, "oldTable", info.Table, "newTable", realTable)
if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
}
}
// cleanup old main rules
if err := c.cleanupMainRule(rules); err != nil {
return err
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4); err != nil {
return fmt.Errorf("failed to set v4 routes: %w", err)
}
if err := c.setRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6); err != nil {
return fmt.Errorf("failed to set v6 routes: %w", err)
}
return nil
}
func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
for _, rule := range rules {
existingRules, err := c.nl.RuleList(int(rule.IpFamily))
func (c LocalController) setRoutesForFamily(
interfaceId domain.InterfaceIdentifier,
link netlink.Link,
family int,
table int,
fwMark uint32,
cidrs []domain.Cidr,
) error {
// first create or update the routes
for _, cidr := range cidrs {
err := c.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: cidr.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", rule.IpFamily, err)
return fmt.Errorf("failed to add/update route %s on table %d for interface %s: %w",
cidr.String(), table, interfaceId, err)
}
}
ruleExists := false
for _, existingRule := range existingRules {
if rule.FwMark == existingRule.Mark && rule.Table == existingRule.Table {
ruleExists = true
break
// next remove old routes
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
interfaceId, family, err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
if ruleExists {
continue // rule already exists, no need to recreate it
route := domain.CidrFromIpNet(*rawRoute.Dst)
if slices.Contains(cidrs, route) {
continue
}
// create a missing rule
if err := c.nl.RouteDel(&rawRoute); err != nil {
return fmt.Errorf("failed to remove deprecated route %s from interface %s: %w", route, interfaceId, err)
}
}
// next, update route rules for normal routes
if table == 0 {
return nil // no need to update route rules as we are using the default table
}
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family-id %d: %w", family, err)
}
ruleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
return rule.Mark == fwMark && rule.Table == table
})
if !ruleExists {
if err := c.nl.RuleAdd(&netlink.Rule{
Family: int(rule.IpFamily),
Table: rule.Table,
Mark: rule.FwMark,
Family: family,
Table: table,
Mark: fwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
@ -682,15 +756,102 @@ func (c LocalController) setFwMarkRules(rules []domain.RouteRule) error {
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup %s rule for fwmark %d and table %d: %w",
rule.IpFamily, rule.FwMark, rule.Table, err)
return fmt.Errorf("failed to setup rule for fwmark %d and table %d for family-id %d: %w",
fwMark, table, family, err)
}
}
mainRuleExists := slices.ContainsFunc(existingRules, func(rule netlink.Rule) bool {
return rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN
})
if !mainRuleExists && domain.ContainsDefaultRoute(cidrs) {
err = c.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: c.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
})
}
// finally, clean up extra main rules - only one rule is allowed
existingRules, err = c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing main rules for family-id %d: %w", family, err)
}
mainRuleCount := 0
for _, rule := range existingRules {
if rule.SuppressPrefixlen == 0 && rule.Table == unix.RT_TABLE_MAIN {
mainRuleCount++
}
if mainRuleCount > 1 {
if err := c.nl.RuleDel(&rule); err != nil {
return fmt.Errorf("failed to remove extra main rule for family-id %d: %w", family, err)
}
}
}
return nil
}
func (c LocalController) getOrCreateRoutingTableAndFwMark(
link netlink.Link,
tableIn int,
fwMarkIn uint32,
) (
table int,
fwmark uint32,
err error,
) {
table = tableIn
fwmark = fwMarkIn
if fwmark == 0 {
// generate a new (temporary) firewall mark based on the interface index
fwmark = uint32(c.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
}
if table == 0 {
table = int(fwmark) // generate a new routing table base on interface index
}
return
}
func (c LocalController) updateFwMarkOnInterface(interfaceId domain.InterfaceIdentifier, fwMark int) error {
// apply the new fwmark to the wireguard interface
err := c.wg.ConfigureDevice(string(interfaceId), wgtypes.Config{
FirewallMark: &fwMark,
})
if err != nil {
return fmt.Errorf("failed to update fwmark of interface %s to: %d: %w", interfaceId, fwMark, err)
}
return nil
}
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
prio := c.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio++
}
}
return prio
}
func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a priority of 32766
prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
@ -708,126 +869,145 @@ func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
return prio
}
func (c LocalController) setMainRule(rules []domain.RouteRule) error {
var family domain.IpFamily
shouldHaveMainRule := false
for _, rule := range rules {
family = rule.IpFamily
if rule.HasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
return nil
}
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
func (c LocalController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error {
interfaceId := info.Interface.Identifier
slog.Debug("removing linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
"cidrs", info.AllowedIps)
existingRules, err := c.nl.RuleList(int(family))
wgDev, err := c.wg.Device(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
slog.Debug("wg device already removed, route cleanup might be incomplete", "interface", interfaceId)
wgDev = nil
}
ruleExists := false
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
ruleExists = true
break
}
}
if ruleExists {
return nil // rule already exists, skip re-creation
}
if err := c.nl.RuleAdd(&netlink.Rule{
Family: int(family),
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: c.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
return nil
}
func (c LocalController) getMainRulePriority(existingRules []netlink.Rule) int {
priority := c.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == priority {
isFresh = false
break
}
}
if isFresh {
break
} else {
priority++
}
}
return priority
}
func (c LocalController) cleanupMainRule(rules []domain.RouteRule) error {
var family domain.IpFamily
for _, rule := range rules {
family = rule.IpFamily
break
}
existingRules, err := c.nl.RuleList(int(family))
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to get existing rules for family %s: %w", family, err)
slog.Debug("physical link already removed, route cleanup might be incomplete", "interface", interfaceId)
link = nil
}
shouldHaveMainRule := false
for _, rule := range rules {
if rule.HasDefault == true {
shouldHaveMainRule = true
break
fwMark := info.FwMark
if wgDev != nil && info.FwMark == 0 {
fwMark = uint32(wgDev.FirewallMark)
}
table := info.Table
if wgDev != nil && info.Table == 0 {
table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior
}
linkIndex := -1
if link != nil {
linkIndex = link.Attrs().Index
}
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
if linkIndex > 0 {
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4)
if err != nil {
return fmt.Errorf("failed to remove v4 routes: %w", err)
}
err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6)
if err != nil {
return fmt.Errorf("failed to remove v6 routes: %w", err)
}
}
mainRules := 0
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
mainRules++
if table > 0 {
err = c.removeRouteRulesForTable(netlink.FAMILY_V4, realTable)
if err != nil {
return fmt.Errorf("failed to remove v4 route rules for %s: %w", interfaceId, err)
}
}
removalCount := 0
if mainRules > 1 {
removalCount = mainRules - 1 // we only want one single rule
}
if !shouldHaveMainRule {
removalCount = mainRules
}
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
if removalCount > 0 {
existingRule.Family = int(family) // set family, somehow the RuleList method does not populate the family field
if err := c.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete main rule: %w", err)
}
removalCount--
}
err = c.removeRouteRulesForTable(netlink.FAMILY_V6, realTable)
if err != nil {
return fmt.Errorf("failed to remove v6 route rules for %s: %w", interfaceId, err)
}
}
return nil
}
func (c LocalController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
func (c LocalController) removeRoutesForFamily(
interfaceId domain.InterfaceIdentifier,
link netlink.Link,
family int,
table int,
fwMark uint32,
cidrs []domain.Cidr,
) error {
// first remove all rules
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
for _, existingRule := range existingRules {
if fwMark == existingRule.Mark && table == existingRule.Table {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := c.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete old fwmark rule: %w", err)
}
}
}
// next remove all routes
rawRoutes, err := c.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes for interface %s and family-id %d: %w",
interfaceId, family, err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
if rawRoute.Table != table {
continue // ignore routes from other tables
}
route := domain.CidrFromIpNet(*rawRoute.Dst)
if !slices.Contains(cidrs, route) {
continue // only remove routes that were previously added
}
if err := c.nl.RouteDel(&rawRoute); err != nil {
return fmt.Errorf("failed to remove old route %s from interface %s: %w", route, interfaceId, err)
}
}
return nil
}
func (c LocalController) removeRouteRulesForTable(
family int,
table int,
) error {
existingRules, err := c.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing route rules for family-id %d: %w", family, err)
}
for _, existingRule := range existingRules {
if existingRule.Table == table {
err := c.nl.RuleDel(&existingRule)
if err != nil {
return fmt.Errorf("failed to delete old rule for table %d and family-id %d: %w", table, family, err)
}
}
}
return nil
}
// endregion routing-related

View File

@ -15,6 +15,9 @@ import (
"github.com/h44z/wg-portal/internal/lowlevel"
)
const MikrotikRouteDistance = 5
const MikrotikDefaultRoutingTable = "main"
type MikrotikController struct {
coreCfg *config.Config
cfg *config.BackendMikrotik
@ -22,8 +25,9 @@ type MikrotikController struct {
client *lowlevel.MikrotikApiClient
// Add mutexes to prevent race conditions
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
interfaceMutexes sync.Map // map[domain.InterfaceIdentifier]*sync.Mutex
peerMutexes sync.Map // map[domain.PeerIdentifier]*sync.Mutex
coreMutex sync.Mutex // for updating the core configuration such as routing table or DNS settings
}
func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik) (*MikrotikController, error) {
@ -40,6 +44,7 @@ func NewMikrotikController(coreCfg *config.Config, cfg *config.BackendMikrotik)
interfaceMutexes: sync.Map{},
peerMutexes: sync.Map{},
coreMutex: sync.Mutex{},
}, nil
}
@ -763,33 +768,404 @@ func (c *MikrotikController) DeletePeer(
// region wg-quick-related
func (c *MikrotikController) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
func (c *MikrotikController) ExecuteInterfaceHook(
_ context.Context,
_ domain.InterfaceIdentifier,
_ string,
) error {
// TODO implement me
panic("implement me")
slog.Error("interface hooks are not yet supported for Mikrotik backends, please open an issue on GitHub")
return nil
}
func (c *MikrotikController) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
// TODO implement me
panic("implement me")
func (c *MikrotikController) SetDNS(
ctx context.Context,
_ domain.InterfaceIdentifier,
dnsStr, _ string,
) error {
// Lock the interface to prevent concurrent modifications
c.coreMutex.Lock()
defer c.coreMutex.Unlock()
// check if the server is already configured
wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
PropList: []string{"servers"},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
}
var existingServers []string
existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
newServers := strings.Split(dnsStr, ",")
mergedServers := slices.Clone(existingServers)
for _, s := range newServers {
if s == "" {
continue
}
if !slices.Contains(mergedServers, s) {
mergedServers = append(mergedServers, s)
}
}
mergedServersStr := strings.Join(mergedServers, ",")
reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
"servers": mergedServersStr,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
}
return nil
}
func (c *MikrotikController) UnsetDNS(id domain.InterfaceIdentifier) error {
// TODO implement me
panic("implement me")
func (c *MikrotikController) UnsetDNS(
ctx context.Context,
_ domain.InterfaceIdentifier,
dnsStr, _ string,
) error {
// Lock the interface to prevent concurrent modifications
c.coreMutex.Lock()
defer c.coreMutex.Unlock()
// retrieve current DNS settings
wgReply := c.client.Get(ctx, "/ip/dns", &lowlevel.MikrotikRequestOptions{
PropList: []string{"servers"},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard dns settings: %v", wgReply.Error)
}
var existingServers []string
existingServers = append(existingServers, strings.Split(wgReply.Data.GetString("servers"), ",")...)
oldServers := strings.Split(dnsStr, ",")
mergedServers := make([]string, 0, len(existingServers))
for _, s := range existingServers {
if s == "" {
continue
}
if !slices.Contains(oldServers, s) {
mergedServers = append(mergedServers, s) // only keep the servers that are not in the old list
}
}
mergedServersStr := strings.Join(mergedServers, ",")
reply := c.client.ExecList(ctx, "/ip/dns/set", lowlevel.GenericJsonObject{
"servers": mergedServersStr,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to set DNS servers: %s: %v", mergedServersStr, reply.Error)
}
return nil
}
// endregion wg-quick-related
// region routing-related
func (c *MikrotikController) SyncRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
func (c *MikrotikController) SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
interfaceId := info.Interface.Identifier
slog.Debug("setting mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
// Mikrotik needs some time to apply the changes.
// If we don't wait, the routes might get created multiple times as the dynamic routes are not yet available.
time.Sleep(2 * time.Second)
tableName, err := c.getOrCreateRoutingTables(ctx, info.Interface.Identifier, info.TableStr)
if err != nil {
return fmt.Errorf("failed to get or create routing table for %s: %v", interfaceId, err)
}
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
err = c.setRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
if err != nil {
return fmt.Errorf("failed to set IPv4 routes for %s: %v", interfaceId, err)
}
err = c.setRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
if err != nil {
return fmt.Errorf("failed to set IPv6 routes for %s: %v", interfaceId, err)
}
return nil
}
func (c *MikrotikController) DeleteRouteRules(_ context.Context, rules []domain.RouteRule) error {
// TODO implement me
panic("implement me")
func (c *MikrotikController) resolveRouteTableName(name string) string {
name = strings.TrimSpace(name)
var mikrotikTableName string
switch strings.ToLower(name) {
case "", "0":
mikrotikTableName = MikrotikDefaultRoutingTable
case MikrotikDefaultRoutingTable:
return fmt.Sprintf("wgportal-%s",
MikrotikDefaultRoutingTable) // if the Mikrotik Main table should be used, the table-name should be left empty or set to "0".
default:
mikrotikTableName = name
}
return mikrotikTableName
}
func (c *MikrotikController) getOrCreateRoutingTables(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
table string,
) (string, error) {
// retrieve current routing tables
wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "dynamic", "fib", "name",
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return "", fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
}
wantedTableName := c.resolveRouteTableName(table)
// check if the table already exists
for _, table := range wgReply.Data {
if table.GetString("name") == wantedTableName {
return wantedTableName, nil // already exists, nothing to do
}
}
// create the table if it does not exist
createReply := c.client.Create(ctx, "/routing/table", lowlevel.GenericJsonObject{
"name": wantedTableName,
"comment": fmt.Sprintf("Routing Table for %s", interfaceId),
"fib": strconv.FormatBool(true),
})
if createReply.Status != lowlevel.MikrotikApiStatusOk {
return "", fmt.Errorf("failed to create routing table %s: %v", wantedTableName, createReply.Error)
}
return wantedTableName, nil
}
func (c *MikrotikController) setRoutesForFamily(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
ipV6 bool,
table string,
cidrs []domain.Cidr,
) error {
apiPath := "/ip/route"
if ipV6 {
apiPath = "/ipv6/route"
}
// retrieve current routes
wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
"routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
},
Filters: map[string]string{
"gateway": string(interfaceId),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
}
// first create or update the routes
for _, cidr := range cidrs {
// check if the route already exists
exists := false
for _, route := range wgReply.Data {
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
if err != nil {
slog.Warn("failed to parse route destination address",
"cidr", route.GetString("dst-address"), "error", err)
continue
}
if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
exists = true
break
}
}
if exists {
continue // route already exists, nothing to do
}
// create the route
reply := c.client.Create(ctx, apiPath, lowlevel.GenericJsonObject{
"gateway": string(interfaceId),
"dst-address": cidr.String(),
"distance": strconv.Itoa(MikrotikRouteDistance),
"disabled": strconv.FormatBool(false),
"routing-table": table,
})
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to create new route %s via %s: %v", cidr.String(), interfaceId, reply.Error)
}
}
// finally, remove the routes that are not in the new list
for _, route := range wgReply.Data {
if route.GetBool("dynamic") {
continue // dynamic routes are not managed by the controller, nothing to do
}
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
if err != nil {
slog.Warn("failed to parse route destination address",
"cidr", route.GetString("dst-address"), "error", err)
continue
}
valid := false
for _, cidr := range cidrs {
if existingRoute.EqualPrefix(cidr) {
valid = true
break
}
}
if valid {
continue // route is still valid, nothing to do
}
// remove the route
reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to remove outdated route %s: %v", existingRoute.String(), reply.Error)
}
}
return nil
}
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
func (c *MikrotikController) RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
interfaceId := info.Interface.Identifier
slog.Debug("removing mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
tableName := c.resolveRouteTableName(info.TableStr)
cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
err := c.removeRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
if err != nil {
return fmt.Errorf("failed to remove IPv4 routes for %s: %v", interfaceId, err)
}
err = c.removeRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
if err != nil {
return fmt.Errorf("failed to remove IPv6 routes for %s: %v", interfaceId, err)
}
err = c.removeRoutingTable(ctx, tableName)
if err != nil {
return fmt.Errorf("failed to remove routing table for %s: %v", interfaceId, err)
}
return nil
}
func (c *MikrotikController) removeRoutesForFamily(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
ipV6 bool,
table string,
cidrs []domain.Cidr,
) error {
apiPath := "/ip/route"
if ipV6 {
apiPath = "/ipv6/route"
}
// retrieve current routes
wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
"routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
},
Filters: map[string]string{
"gateway": string(interfaceId),
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
}
// remove the routes from the list
for _, route := range wgReply.Data {
if route.GetBool("dynamic") {
continue // dynamic routes are not managed by the controller, nothing to do
}
existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
if err != nil {
slog.Warn("failed to parse route destination address",
"cidr", route.GetString("dst-address"), "error", err)
continue
}
remove := false
for _, cidr := range cidrs {
if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
remove = true
break
}
}
if !remove {
continue // route is still valid, nothing to do
}
// remove the route
reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to remove old route %s: %v", existingRoute.String(), reply.Error)
}
}
return nil
}
func (c *MikrotikController) removeRoutingTable(
ctx context.Context,
table string,
) error {
if table == MikrotikDefaultRoutingTable {
return nil // we cannot remove the default table
}
// retrieve current routing tables
wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
PropList: []string{
".id", "dynamic", "fib", "name",
},
})
if wgReply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
}
for _, existingTable := range wgReply.Data {
if existingTable.GetBool("dynamic") {
continue // dynamic tables are not managed by the controller, nothing to do
}
if existingTable.GetString("name") != table {
continue // not the table we want to remove
}
// remove the table
reply := c.client.Delete(ctx, "/routing/table/"+existingTable.GetString(".id"))
if reply.Status != lowlevel.MikrotikApiStatusOk {
return fmt.Errorf("failed to remove routing table %s: %v", table, reply.Error)
}
return nil
}
return nil
}
// endregion routing-related

View File

@ -1,113 +0,0 @@
package adapters
import (
"bytes"
"fmt"
"log/slog"
"os/exec"
"strings"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/domain"
)
// WgQuickRepo implements higher level wg-quick like interactions like setting DNS, routing tables or interface hooks.
type WgQuickRepo struct {
shellCmd string
resolvConfIfacePrefix string
}
// NewWgQuickRepo creates a new WgQuickRepo instance.
func NewWgQuickRepo() *WgQuickRepo {
return &WgQuickRepo{
shellCmd: "bash",
resolvConfIfacePrefix: "tun.",
}
}
// ExecuteInterfaceHook executes the given hook command.
// The hook command can contain the following placeholders:
//
// %i: the interface identifier.
func (r *WgQuickRepo) ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error {
if hookCmd == "" {
return nil
}
slog.Debug("executing interface hook", "interface", id, "hook", hookCmd)
err := r.exec(hookCmd, id)
if err != nil {
return fmt.Errorf("failed to exec hook: %w", err)
}
return nil
}
// SetDNS sets the DNS settings for the given interface. It uses resolvconf to set the DNS settings.
func (r *WgQuickRepo) SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error {
if dnsStr == "" && dnsSearchStr == "" {
return nil
}
dnsServers := internal.SliceString(dnsStr)
dnsSearchDomains := internal.SliceString(dnsSearchStr)
dnsCommand := "resolvconf -a %resPref%i -m 0 -x"
dnsCommandInput := make([]string, 0, len(dnsServers)+len(dnsSearchDomains))
for _, dnsServer := range dnsServers {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("nameserver %s", dnsServer))
}
for _, searchDomain := range dnsSearchDomains {
dnsCommandInput = append(dnsCommandInput, fmt.Sprintf("search %s", searchDomain))
}
err := r.exec(dnsCommand, id, dnsCommandInput...)
if err != nil {
return fmt.Errorf(
"failed to set dns settings (is resolvconf available?, for systemd create this symlink: ln -s /usr/bin/resolvectl /usr/local/bin/resolvconf): %w",
err,
)
}
return nil
}
// UnsetDNS unsets the DNS settings for the given interface. It uses resolvconf to unset the DNS settings.
func (r *WgQuickRepo) UnsetDNS(id domain.InterfaceIdentifier) error {
dnsCommand := "resolvconf -d %resPref%i -f"
err := r.exec(dnsCommand, id)
if err != nil {
return fmt.Errorf("failed to unset dns settings: %w", err)
}
return nil
}
func (r *WgQuickRepo) replaceCommandPlaceHolders(command string, interfaceId domain.InterfaceIdentifier) string {
command = strings.ReplaceAll(command, "%resPref", r.resolvConfIfacePrefix)
return strings.ReplaceAll(command, "%i", string(interfaceId))
}
func (r *WgQuickRepo) exec(command string, interfaceId domain.InterfaceIdentifier, stdin ...string) error {
commandWithInterfaceName := r.replaceCommandPlaceHolders(command, interfaceId)
cmd := exec.Command(r.shellCmd, "-ce", commandWithInterfaceName)
if len(stdin) > 0 {
b := &bytes.Buffer{}
for _, ln := range stdin {
if _, err := fmt.Fprint(b, ln); err != nil {
return err
}
}
cmd.Stdin = b
}
out, err := cmd.CombinedOutput() // execute and wait for output
if err != nil {
return fmt.Errorf("failed to exexute shell command %s: %w", commandWithInterfaceName, err)
}
slog.Debug("executed shell command",
"command", commandWithInterfaceName,
"output", string(out))
return nil
}

View File

@ -1550,6 +1550,38 @@
}
}
},
"/user/{id}/change-password": {
"post": {
"produces": [
"application/json"
],
"tags": [
"Users"
],
"summary": "Change the password for the given user.",
"operationId": "users_handleChangePasswordPost",
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/model.User"
}
},
"400": {
"description": "Bad Request",
"schema": {
"$ref": "#/definitions/model.Error"
}
},
"500": {
"description": "Internal Server Error",
"schema": {
"$ref": "#/definitions/model.Error"
}
}
}
}
},
"/user/{id}/interfaces": {
"get": {
"produces": [
@ -2159,6 +2191,10 @@
}
]
},
"UserDisplayName": {
"description": "the owner display name",
"type": "string"
},
"UserIdentifier": {
"description": "the owner",
"type": "string"

View File

@ -322,6 +322,9 @@ definitions:
allOf:
- $ref: '#/definitions/model.ConfigOption-string'
description: the routing table
UserDisplayName:
description: the owner display name
type: string
UserIdentifier:
description: the owner
type: string
@ -1442,6 +1445,27 @@ paths:
summary: Enable the REST API for the given user.
tags:
- Users
/user/{id}/change-password:
post:
operationId: users_handleChangePasswordPost
produces:
- application/json
responses:
"200":
description: OK
schema:
$ref: '#/definitions/model.User'
"400":
description: Bad Request
schema:
$ref: '#/definitions/model.Error'
"500":
description: Internal Server Error
schema:
$ref: '#/definitions/model.Error'
summary: Change the password for the given user.
tags:
- Users
/user/{id}/interfaces:
get:
operationId: users_handleInterfacesGet

View File

@ -17,11 +17,6 @@
"paths": {
"/interface/all": {
"get": {
"security": [
{
"BasicAuth": []
}
],
"produces": [
"application/json"
],
@ -52,16 +47,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/interface/by-id/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/interface/by-id/{id}": {
"get": {
"produces": [
"application/json"
],
@ -110,14 +105,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"put": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"put": {
"description": "This endpoint updates an existing interface with the provided data. All required fields must be filled (e.g. name, private key, public key, ...).",
"produces": [
"application/json"
@ -182,14 +177,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"delete": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"delete": {
"produces": [
"application/json"
],
@ -241,16 +236,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/interface/new": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/interface/new": {
"post": {
"description": "This endpoint creates a new interface with the provided data. All required fields must be filled (e.g. name, private key, public key, ...).",
"produces": [
"application/json"
@ -308,16 +303,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/interface/prepare": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/interface/prepare": {
"get": {
"description": "This endpoint returns a new interface with default values (fresh key pair, valid name, new IP address pool, ...).",
"produces": [
"application/json"
@ -352,16 +347,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/metrics/by-interface/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/metrics/by-interface/{id}": {
"get": {
"produces": [
"application/json"
],
@ -410,16 +405,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/metrics/by-peer/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/metrics/by-peer/{id}": {
"get": {
"produces": [
"application/json"
],
@ -468,16 +463,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/metrics/by-user/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/metrics/by-user/{id}": {
"get": {
"produces": [
"application/json"
],
@ -526,16 +521,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/by-id/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/by-id/{id}": {
"get": {
"description": "Normal users can only access their own records. Admins can access all records.",
"produces": [
"application/json"
@ -585,14 +580,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"put": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"put": {
"description": "Only admins can update existing records. The peer record must contain all required fields (e.g., public key, allowed IPs).",
"produces": [
"application/json"
@ -657,14 +652,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"delete": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"delete": {
"produces": [
"application/json"
],
@ -716,16 +711,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/by-interface/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/by-interface/{id}": {
"get": {
"produces": [
"application/json"
],
@ -765,16 +760,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/by-user/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/by-user/{id}": {
"get": {
"description": "Normal users can only access their own records. Admins can access all records.",
"produces": [
"application/json"
@ -815,16 +810,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/new": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/new": {
"post": {
"description": "Only admins can create new records. The peer record must contain all required fields (e.g., public key, allowed IPs).",
"produces": [
"application/json"
@ -882,16 +877,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/peer/prepare/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/peer/prepare/{id}": {
"get": {
"description": "This endpoint is used to prepare a new peer record. The returned data contains a fresh key pair and valid ip address.",
"produces": [
"application/json"
@ -947,16 +942,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/data/peer-config": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/data/peer-config": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"text/plain",
@ -1013,16 +1008,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/data/peer-qr": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/data/peer-qr": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"image/png",
@ -1079,16 +1074,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/data/user-info": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/data/user-info": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"application/json"
@ -1149,16 +1144,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/provisioning/new-peer": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/provisioning/new-peer": {
"post": {
"description": "Normal users can only create new peers if self provisioning is allowed. Admins can always add new peers.",
"produces": [
"application/json"
@ -1216,16 +1211,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/user/all": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/user/all": {
"get": {
"produces": [
"application/json"
],
@ -1256,16 +1251,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/user/by-id/{id}": {
"get": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/user/by-id/{id}": {
"get": {
"description": "Normal users can only access their own record. Admins can access all records.",
"produces": [
"application/json"
@ -1315,14 +1310,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"put": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"put": {
"description": "Only admins can update existing records.",
"produces": [
"application/json"
@ -1387,14 +1382,14 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"delete": {
},
"security": [
{
"BasicAuth": []
}
],
]
},
"delete": {
"produces": [
"application/json"
],
@ -1446,16 +1441,16 @@
"$ref": "#/definitions/models.Error"
}
}
}
}
},
"/user/new": {
"post": {
},
"security": [
{
"BasicAuth": []
}
],
]
}
},
"/user/new": {
"post": {
"description": "Only admins can create new records.",
"produces": [
"application/json"
@ -1513,7 +1508,12 @@
"$ref": "#/definitions/models.Error"
}
}
}
},
"security": [
{
"BasicAuth": []
}
]
}
}
},

View File

@ -2,6 +2,8 @@ package backend
import (
"context"
"fmt"
"strings"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@ -70,6 +72,44 @@ func (u UserService) DeactivateApi(ctx context.Context, id domain.UserIdentifier
return u.users.DeactivateApi(ctx, id)
}
func (u UserService) ChangePassword(ctx context.Context, id domain.UserIdentifier, oldPassword, newPassword string) (*domain.User, error) {
oldPassword = strings.TrimSpace(oldPassword)
newPassword = strings.TrimSpace(newPassword)
if newPassword == "" {
return nil, fmt.Errorf("new password must not be empty")
}
// ensure that the new password is different from the old one
if oldPassword == newPassword {
return nil, fmt.Errorf("new password must be different from the old one")
}
user, err := u.users.GetUser(ctx, id)
if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err)
}
// ensure that the user uses the database backend; otherwise we can't change the password
if user.Source != domain.UserSourceDatabase {
return nil, fmt.Errorf("user source %s does not support password changes", user.Source)
}
// validate old password
if user.CheckPassword(oldPassword) != nil {
return nil, fmt.Errorf("current password is invalid")
}
user.Password = domain.PrivateString(newPassword)
// ensure that the new password is strong enough
if err := user.HasWeakPassword(u.cfg.Auth.MinPasswordLength); err != nil {
return nil, err
}
return u.users.UpdateUser(ctx, user)
}
func (u UserService) GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error) {
return u.wg.GetUserPeers(ctx, id)
}

View File

@ -28,6 +28,8 @@ type UserService interface {
ActivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// DeactivateApi disables the API for the user with the given id.
DeactivateApi(ctx context.Context, id domain.UserIdentifier) (*domain.User, error)
// ChangePassword changes the password for the user with the given id.
ChangePassword(ctx context.Context, id domain.UserIdentifier, oldPassword, newPassword string) (*domain.User, error)
// GetUserPeers returns all peers for the given user.
GetUserPeers(ctx context.Context, id domain.UserIdentifier) ([]domain.Peer, error)
// GetUserPeerStats returns all peer stats for the given user.
@ -75,6 +77,7 @@ func (e UserEndpoint) RegisterRoutes(g *routegroup.Bundle) {
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("GET /{id}/interfaces", e.handleInterfacesGet())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/enable", e.handleApiEnablePost())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/api/disable", e.handleApiDisablePost())
apiGroup.With(e.authenticator.UserIdMatch("id")).HandleFunc("POST /{id}/change-password", e.handleChangePasswordPost())
}
// handleAllGet returns a gorm Handler function.
@ -391,3 +394,68 @@ func (e UserEndpoint) handleApiDisablePost() http.HandlerFunc {
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
}
}
// handleChangePasswordPost returns a gorm Handler function.
//
// @ID users_handleChangePasswordPost
// @Tags Users
// @Summary Change the password for the given user.
// @Produce json
// @Success 200 {object} model.User
// @Failure 400 {object} model.Error
// @Failure 500 {object} model.Error
// @Router /user/{id}/change-password [post]
func (e UserEndpoint) handleChangePasswordPost() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userId := Base64UrlDecode(request.Path(r, "id"))
if userId == "" {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusInternalServerError, Message: "missing id parameter"})
return
}
var passwordData struct {
OldPassword string `json:"OldPassword"`
Password string `json:"Password"`
PasswordRepeat string `json:"PasswordRepeat"`
}
if err := request.BodyJson(r, &passwordData); err != nil {
respond.JSON(w, http.StatusBadRequest, model.Error{Code: http.StatusBadRequest, Message: err.Error()})
return
}
if passwordData.OldPassword == "" {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "old password missing"})
return
}
if passwordData.Password == "" {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "new password missing"})
return
}
if passwordData.OldPassword == passwordData.Password {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "password did not change"})
return
}
if passwordData.Password != passwordData.PasswordRepeat {
respond.JSON(w, http.StatusBadRequest,
model.Error{Code: http.StatusBadRequest, Message: "password mismatch"})
return
}
user, err := e.userService.ChangePassword(r.Context(), domain.UserIdentifier(userId),
passwordData.OldPassword, passwordData.Password)
if err != nil {
respond.JSON(w, http.StatusInternalServerError,
model.Error{Code: http.StatusInternalServerError, Message: err.Error()})
return
}
respond.JSON(w, http.StatusOK, model.NewUser(user, false))
}
}

View File

@ -19,15 +19,16 @@ import (
// PlainOauthAuthenticator is an authenticator that uses OAuth for authentication.
// User information is retrieved from the specified user info endpoint.
type PlainOauthAuthenticator struct {
name string
cfg *oauth2.Config
userInfoEndpoint string
client *http.Client
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
allowedDomains []string
name string
cfg *oauth2.Config
userInfoEndpoint string
client *http.Client
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
sensitiveInfoLogging bool
allowedDomains []string
}
func newPlainOauthAuthenticator(
@ -57,6 +58,7 @@ func newPlainOauthAuthenticator(
provider.userAdminMapping = &cfg.AdminMapping
provider.registrationEnabled = cfg.RegistrationEnabled
provider.userInfoLogging = cfg.LogUserInfo
provider.sensitiveInfoLogging = cfg.LogSensitiveInfo
provider.allowedDomains = cfg.AllowedDomains
return provider, nil
@ -110,6 +112,10 @@ func (p PlainOauthAuthenticator) GetUserInfo(
response, err := p.client.Do(req)
if err != nil {
if p.sensitiveInfoLogging {
slog.Debug("OAuth: failed to get user info", "endpoint", p.userInfoEndpoint,
"token", token, "error", err)
}
return nil, fmt.Errorf("failed to get user info: %w", err)
}
defer internal.LogClose(response.Body)
@ -121,11 +127,15 @@ func (p PlainOauthAuthenticator) GetUserInfo(
var userFields map[string]any
err = json.Unmarshal(contents, &userFields)
if err != nil {
if p.sensitiveInfoLogging {
slog.Debug("OAuth: failed to parse user info", "endpoint", p.userInfoEndpoint,
"token", token, "contents", contents, "error", err)
}
return nil, fmt.Errorf("failed to parse user info: %w", err)
}
if p.userInfoLogging {
slog.Debug("OAuth user info",
slog.Debug("OAuth: user info debug",
"source", p.name,
"info", string(contents))
}

View File

@ -16,15 +16,16 @@ import (
// OidcAuthenticator is an authenticator for OpenID Connect providers.
type OidcAuthenticator struct {
name string
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
cfg *oauth2.Config
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
allowedDomains []string
name string
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
cfg *oauth2.Config
userInfoMapping config.OauthFields
userAdminMapping *config.OauthAdminMapping
registrationEnabled bool
userInfoLogging bool
sensitiveInfoLogging bool
allowedDomains []string
}
func newOidcAuthenticator(
@ -58,6 +59,7 @@ func newOidcAuthenticator(
provider.userAdminMapping = &cfg.AdminMapping
provider.registrationEnabled = cfg.RegistrationEnabled
provider.userInfoLogging = cfg.LogUserInfo
provider.sensitiveInfoLogging = cfg.LogSensitiveInfo
provider.allowedDomains = cfg.AllowedDomains
return provider, nil
@ -102,24 +104,40 @@ func (o OidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token,
) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: token does not contain id_token", "token", token, "nonce", nonce)
}
return nil, errors.New("token does not contain id_token")
}
idToken, err := o.verifier.Verify(ctx, rawIDToken)
if err != nil {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: failed to validate id_token", "token", token, "id_token", rawIDToken, "nonce", nonce,
"error",
err)
}
return nil, fmt.Errorf("failed to validate id_token: %w", err)
}
if idToken.Nonce != nonce {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: id_token nonce mismatch", "token", token, "id_token", idToken, "nonce", nonce)
}
return nil, errors.New("nonce mismatch")
}
var tokenFields map[string]any
if err = idToken.Claims(&tokenFields); err != nil {
if o.sensitiveInfoLogging {
slog.Debug("OIDC: failed to parse extra claims", "token", token, "id_token", idToken, "nonce", nonce,
"error",
err)
}
return nil, fmt.Errorf("failed to parse extra claims: %w", err)
}
if o.userInfoLogging {
contents, _ := json.Marshal(tokenFields)
slog.Debug("OIDC user info",
slog.Debug("OIDC: user info debug",
"source", o.name,
"info", string(contents))
}

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"log/slog"
"net/mail"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
@ -101,29 +102,15 @@ func (m Manager) SendPeerEmail(ctx context.Context, linkOnly bool, style string,
}
if peer.UserIdentifier == "" {
slog.Debug("skipping peer email",
"peer", peerId,
"reason", "no user linked")
continue
return fmt.Errorf("peer %s has no user linked, no email is sent", peerId)
}
user, err := m.users.GetUser(ctx, peer.UserIdentifier)
if err != nil {
slog.Debug("skipping peer email",
"peer", peerId,
"reason", "unable to fetch user",
"error", err)
continue
email, user := m.resolveEmail(ctx, peer)
if email == "" {
return fmt.Errorf("peer %s has no valid email address, no email is sent", peerId)
}
if user.Email == "" {
slog.Debug("skipping peer email",
"peer", peerId,
"reason", "user has no mail address")
continue
}
err = m.sendPeerEmail(ctx, linkOnly, style, user, peer)
err = m.sendPeerEmail(ctx, linkOnly, style, &user, peer)
if err != nil {
return fmt.Errorf("failed to send peer email for %s: %w", peerId, err)
}
@ -194,3 +181,37 @@ func (m Manager) sendPeerEmail(
return nil
}
func (m Manager) resolveEmail(ctx context.Context, peer *domain.Peer) (string, domain.User) {
user, err := m.users.GetUser(ctx, peer.UserIdentifier)
if err != nil {
if m.cfg.Mail.AllowPeerEmail {
_, err := mail.ParseAddress(string(peer.UserIdentifier)) // test if the user identifier is a valid email address
if err == nil {
slog.Debug("peer email: using user-identifier as email",
"peer", peer.Identifier, "email", peer.UserIdentifier)
return string(peer.UserIdentifier), domain.User{}
} else {
slog.Debug("peer email: skipping peer email",
"peer", peer.Identifier,
"reason", "peer has no user linked and user-identifier is not a valid email address")
return "", domain.User{}
}
} else {
slog.Debug("peer email: skipping peer email",
"peer", peer.Identifier,
"reason", "user has no user linked")
return "", domain.User{}
}
}
if user.Email == "" {
slog.Debug("peer email: skipping peer email",
"peer", peer.Identifier,
"reason", "user has no mail address")
return "", domain.User{}
}
slog.Debug("peer email: using user email", "peer", peer.Identifier, "email", user.Email)
return user.Email, *user
}

View File

@ -4,25 +4,23 @@ import (
"context"
"fmt"
"log/slog"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"sync"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
"github.com/h44z/wg-portal/internal/domain"
"github.com/h44z/wg-portal/internal/lowlevel"
)
// region dependencies
type ControllerManager interface {
// GetController returns the controller for the given interface.
GetController(iface domain.Interface) domain.InterfaceController
}
type InterfaceAndPeerDatabaseRepo interface {
// GetAllInterfaces returns all interfaces
GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
// GetInterfacePeers returns all peers for a given interface
GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
// GetInterface returns the interface with the given identifier.
GetInterface(ctx context.Context, id domain.InterfaceIdentifier) (*domain.Interface, error)
}
type EventBus interface {
@ -30,6 +28,13 @@ type EventBus interface {
Subscribe(topic string, fn interface{}) error
}
type RoutesController interface {
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error
}
// endregion dependencies
type routeRuleInfo struct {
@ -45,28 +50,27 @@ type routeRuleInfo struct {
type Manager struct {
cfg *config.Config
bus EventBus
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
db InterfaceAndPeerDatabaseRepo
bus EventBus
db InterfaceAndPeerDatabaseRepo
wgController ControllerManager
mux *sync.Mutex
}
// NewRouteManager creates a new route manager instance.
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
wg, err := wgctrl.New()
if err != nil {
panic("failed to init wgctrl: " + err.Error())
}
nl := &lowlevel.NetlinkManager{}
func NewRouteManager(
cfg *config.Config,
bus EventBus,
db InterfaceAndPeerDatabaseRepo,
wgController ControllerManager,
) (*Manager, error) {
m := &Manager{
cfg: cfg,
bus: bus,
db: db,
wg: wg,
nl: nl,
db: db,
wgController: wgController,
mux: &sync.Mutex{},
}
m.connectToMessageBus()
@ -85,419 +89,82 @@ func (m Manager) StartBackgroundJobs(_ context.Context) {
// this is a no-op for now
}
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
slog.Debug("handling route update event", "source", srcDescription)
func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
m.mux.Lock() // ensure that only one route update is processed at a time
defer m.mux.Unlock()
err := m.syncRoutes(context.Background())
if err != nil {
slog.Error("failed to synchronize routes",
"source", srcDescription,
"error", err)
slog.Debug("handling route update event", "info", info.String())
if !info.ManagementEnabled() {
return // route management disabled
}
slog.Debug("routes synchronized", "source", srcDescription)
err := m.syncRoutes(context.Background(), info)
if err != nil {
slog.Error("failed to synchronize routes",
"info", info.String(), "error", err)
return
}
slog.Debug("routes synchronized", "info", info.String())
}
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
m.mux.Lock() // ensure that only one route update is processed at a time
defer m.mux.Unlock()
slog.Debug("handling route remove event", "info", info.String())
if !info.ManagementEnabled() {
return // route management disabled
}
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
slog.Error("failed to remove v4 fwmark rules", "error", err)
}
if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
slog.Error("failed to remove v6 fwmark rules", "error", err)
}
slog.Debug("routes removed", "table", info.String())
}
func (m Manager) syncRoutes(ctx context.Context) error {
interfaces, err := m.db.GetAllInterfaces(ctx)
err := m.removeRoutes(context.Background(), info)
if err != nil {
return fmt.Errorf("failed to find all interfaces: %w", err)
slog.Error("failed to synchronize routes",
"info", info.String(), "error", err)
return
}
rules := map[int][]routeRuleInfo{
netlink.FAMILY_V4: nil,
netlink.FAMILY_V6: nil,
}
for _, iface := range interfaces {
if iface.IsDisabled() {
continue // disabled interface does not need route entries
}
if !iface.ManageRoutingTable() {
continue
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
}
allowedIPs := iface.GetAllowedIPs(peers)
defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
link, err := m.nl.LinkByName(string(iface.Identifier))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
}
table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link)
if err != nil {
return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
}
if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
}
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
}
if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
}
if table != 0 {
rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V4,
hasDefault: defRouteV4,
})
}
if table != 0 {
rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
ifaceId: iface.Identifier,
fwMark: fwmark,
table: table,
family: netlink.FAMILY_V6,
hasDefault: defRouteV6,
})
}
}
return m.syncRouteRules(rules)
slog.Debug("routes removed", "info", info.String())
}
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
for family, rules := range allRules {
// update fwmark rules
if err := m.setFwMarkRules(rules, family); err != nil {
return err
}
// update main rule
if err := m.setMainRule(rules, family); err != nil {
return err
}
// cleanup old main rules
if err := m.cleanupMainRule(rules, family); err != nil {
return err
}
}
return nil
}
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
for _, rule := range rules {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
ruleExists = true
break
}
}
if ruleExists {
continue // rule already exists, no need to recreate it
}
// create missing rule
if err := m.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: rule.table,
Mark: rule.fwMark,
Invert: true,
SuppressIfgroup: -1,
SuppressPrefixlen: -1,
Priority: m.getRulePriority(existingRules),
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
}
}
return nil
}
func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error {
existingRules, err := m.nl.RuleList(family)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
for _, existingRule := range existingRules {
if fwmark == existingRule.Mark && table == existingRule.Table {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := m.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete fwmark rule: %w", err)
}
}
}
return nil
}
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
if !shouldHaveMainRule {
func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
if !ok {
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
}
existingRules, err := m.nl.RuleList(family)
if !info.Interface.ManageRoutingTable() {
slog.Debug("interface does not manage routing table, skipping route update",
"interface", info.Interface.Identifier)
return nil
}
err := rc.SetRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
}
ruleExists := false
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
ruleExists = true
break
}
}
if ruleExists {
return nil // rule already exists, skip re-creation
}
if err := m.nl.RuleAdd(&netlink.Rule{
Family: family,
Table: unix.RT_TABLE_MAIN,
SuppressIfgroup: -1,
SuppressPrefixlen: 0,
Priority: m.getMainRulePriority(existingRules),
Mark: 0,
Mask: nil,
Goto: -1,
Flow: -1,
}); err != nil {
return fmt.Errorf("failed to setup rule for main table: %w", err)
}
return nil
}
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
existingRules, err := m.nl.RuleList(family)
func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
rc, ok := m.wgController.GetController(info.Interface).(RoutesController)
if !ok {
slog.Warn("no capable routes-controller found for interface", "interface", info.Interface.Identifier)
return nil
}
if !info.Interface.ManageRoutingTable() {
slog.Debug("interface does not manage routing table, skipping route removal",
"interface", info.Interface.Identifier)
return nil
}
err := rc.RemoveRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
}
shouldHaveMainRule := false
for _, rule := range rules {
if rule.hasDefault == true {
shouldHaveMainRule = true
break
}
}
mainRules := 0
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
mainRules++
}
}
removalCount := 0
if mainRules > 1 {
removalCount = mainRules - 1 // we only want one single rule
}
if !shouldHaveMainRule {
removalCount = mainRules
}
for _, existingRule := range existingRules {
if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
if removalCount > 0 {
existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
if err := m.nl.RuleDel(&existingRule); err != nil {
return fmt.Errorf("failed to delete main rule: %w", err)
}
removalCount--
}
}
}
return nil
}
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
prio := m.cfg.Advanced.RulePrioOffset
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio++
}
}
return prio
}
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
prio := 32700 // linux main rule has a prio of 32766
for {
isFresh := true
for _, existingRule := range existingRules {
if existingRule.Priority == prio {
isFresh = false
break
}
}
if isFresh {
break
} else {
prio--
}
}
return prio
}
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
for _, allowedIP := range allowedIPs {
err := m.nl.RouteReplace(&netlink.Route{
LinkIndex: link.Attrs().Index,
Dst: allowedIP.IpNet(),
Table: table,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
})
if err != nil {
return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
}
}
return nil
}
func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
LinkIndex: link.Attrs().Index,
Table: unix.RT_TABLE_UNSPEC, // all tables
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
if err != nil {
return fmt.Errorf("failed to fetch raw routes: %w", err)
}
for _, rawRoute := range rawRoutes {
if rawRoute.Dst == nil { // handle default route
var netlinkAddr domain.Cidr
if family == netlink.FAMILY_V4 {
netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
} else {
netlinkAddr, _ = domain.CidrFromString("::/0")
}
rawRoute.Dst = netlinkAddr.IpNet()
}
netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
remove := true
for _, allowedIP := range allowedIPs {
if netlinkAddr == allowedIP {
remove = false
break
}
}
if !remove {
continue
}
err := m.nl.RouteDel(&rawRoute)
if err != nil {
return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
}
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
}
return nil
}
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) (
table int,
fwmark uint32,
err error,
) {
table = iface.GetRoutingTable()
fwmark = iface.FirewallMark
if fwmark == 0 {
// generate a new (temporary) firewall mark based on the interface index
fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
slog.Debug("using fwmark to handle routes",
"interface", iface.Identifier,
"fwmark", fwmark)
// apply the temporary fwmark to the wireguard interface
err = m.setFwMark(iface.Identifier, int(fwmark))
}
if table == 0 {
table = int(fwmark) // generate a new routing table base on interface index
slog.Debug("using routing table to handle default routes",
"interface", iface.Identifier,
"table", table)
}
return
}
func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error {
err := m.wg.ConfigureDevice(string(id), wgtypes.Config{
FirewallMark: &fwmark,
})
if err != nil {
return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err)
}
return nil
}
func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) {
for _, allowedIP := range allowedIPs {
if ipV4 && ipV6 {
break // speed up
}
if allowedIP.Prefix().Bits() == 0 {
if allowedIP.IsV4() {
ipV4 = true
} else {
ipV6 = true
}
}
}
return
}

View File

@ -1,7 +1,6 @@
package wireguard
import (
"context"
"fmt"
"log/slog"
"maps"
@ -12,33 +11,9 @@ import (
"github.com/h44z/wg-portal/internal/domain"
)
type InterfaceController interface {
GetId() domain.InterfaceBackend
GetInterfaces(_ context.Context) ([]domain.PhysicalInterface, error)
GetInterface(_ context.Context, id domain.InterfaceIdentifier) (*domain.PhysicalInterface, error)
GetPeers(_ context.Context, deviceId domain.InterfaceIdentifier) ([]domain.PhysicalPeer, error)
SaveInterface(
_ context.Context,
id domain.InterfaceIdentifier,
updateFunc func(pi *domain.PhysicalInterface) (*domain.PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id domain.InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId domain.InterfaceIdentifier,
id domain.PeerIdentifier,
updateFunc func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId domain.InterfaceIdentifier, id domain.PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*domain.PingerResult, error)
}
type backendInstance struct {
Config config.BackendBase // Config is the configuration for the backend instance.
Implementation InterfaceController
Implementation domain.InterfaceController
}
type ControllerManager struct {
@ -118,11 +93,11 @@ func (c *ControllerManager) logRegisteredControllers() {
}
}
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) InterfaceController {
func (c *ControllerManager) GetControllerByName(backend domain.InterfaceBackend) domain.InterfaceController {
return c.getController(backend, "").Implementation
}
func (c *ControllerManager) GetController(iface domain.Interface) InterfaceController {
func (c *ControllerManager) GetController(iface domain.Interface) domain.InterfaceController {
return c.getController(iface.Backend, iface.Identifier).Implementation
}

View File

@ -38,9 +38,9 @@ type InterfaceAndPeerDatabaseRepo interface {
}
type WgQuickController interface {
ExecuteInterfaceHook(id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(id domain.InterfaceIdentifier) error
ExecuteInterfaceHook(ctx context.Context, id domain.InterfaceIdentifier, hookCmd string) error
SetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
UnsetDNS(ctx context.Context, id domain.InterfaceIdentifier, dnsStr, dnsSearchStr string) error
}
type EventBus interface {
@ -53,11 +53,10 @@ type EventBus interface {
// endregion dependencies
type Manager struct {
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg *ControllerManager
quick WgQuickController
cfg *config.Config
bus EventBus
db InterfaceAndPeerDatabaseRepo
wg *ControllerManager
userLockMap *sync.Map
}
@ -66,7 +65,6 @@ func NewWireGuardManager(
cfg *config.Config,
bus EventBus,
wg *ControllerManager,
quick WgQuickController,
db InterfaceAndPeerDatabaseRepo,
) (*Manager, error) {
m := &Manager{
@ -74,7 +72,6 @@ func NewWireGuardManager(
bus: bus,
wg: wg,
db: db,
quick: quick,
userLockMap: &sync.Map{},
}

View File

@ -453,7 +453,7 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return err
}
existingInterface, err := m.db.GetInterface(ctx, id)
existingInterface, existingPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", id, err)
}
@ -462,21 +462,29 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion not allowed: %w", err)
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *existingInterface,
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
FwMark: existingInterface.FirewallMark,
Table: existingInterface.GetRoutingTable(),
TableStr: existingInterface.RoutingTable,
IsDeleted: true,
})
now := time.Now()
existingInterface.Disabled = &now // simulate a disabled interface
existingInterface.DisabledReason = domain.DisabledReasonDeleted
physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id)
if err := m.handleInterfacePreSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(),
false); err != nil {
return fmt.Errorf("pre-delete hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(existingInterface); err != nil {
if err := m.handleInterfacePreSaveActions(ctx, existingInterface); err != nil {
return fmt.Errorf("pre-delete actions failed: %w", err)
}
if err := m.deleteInterfacePeers(ctx, id); err != nil {
if err := m.deleteInterfacePeers(ctx, existingInterface, existingPeers); err != nil {
return fmt.Errorf("peer deletion failure: %w", err)
}
@ -488,16 +496,12 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
return fmt.Errorf("deletion failure: %w", err)
}
fwMark := existingInterface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: existingInterface.GetRoutingTable(),
})
if err := m.handleInterfacePostSaveHooks(existingInterface, !existingInterface.IsDisabled(), false); err != nil {
if err := m.handleInterfacePostSaveHooks(
ctx,
existingInterface,
!existingInterface.IsDisabled(),
false,
); err != nil {
return fmt.Errorf("post-delete hooks failed: %w", err)
}
@ -516,17 +520,21 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("interface validation failed: %w", err)
}
oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface)
oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err == nil {
oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
}
if err := m.handleInterfacePreSaveHooks(iface, oldEnabled, newEnabled); err != nil {
if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
}
if err := m.handleInterfacePreSaveActions(iface); err != nil {
if err := m.handleInterfacePreSaveActions(ctx, iface); err != nil {
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
@ -569,20 +577,35 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
}
if iface.IsDisabled() {
physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier)
fwMark := iface.FirewallMark
if physicalInterface != nil && fwMark == 0 {
fwMark = physicalInterface.FirewallMark
}
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
FwMark: fwMark,
Table: iface.GetRoutingTable(),
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
} else {
m.bus.Publish(app.TopicRouteUpdate, "interface updated: "+string(iface.Identifier))
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// if the route table changed, ensure that the old entries are remove
if routeTableChanged {
m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
Interface: *oldInterface,
AllowedIps: oldInterface.GetAllowedIPs(peers),
FwMark: oldInterface.FirewallMark,
Table: oldInterface.GetRoutingTable(),
TableStr: oldInterface.RoutingTable,
IsDeleted: true, // mark the old entries as deleted
})
}
}
if err := m.handleInterfacePostSaveHooks(iface, oldEnabled, newEnabled); err != nil {
if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("post-save hooks failed: %w", err)
}
@ -618,60 +641,90 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return iface, nil
}
func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Interface) (oldEnabled, newEnabled bool) {
oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
if err != nil {
return false, !iface.IsDisabled() // if the interface did not exist, we assume it was not enabled
}
return !oldInterface.IsDisabled(), !iface.IsDisabled()
func (m Manager) getInterfaceStateHistory(
oldInterface *domain.Interface,
iface *domain.Interface,
) (oldEnabled, newEnabled, routeTableChanged bool) {
return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
}
func (m Manager) handleInterfacePreSaveActions(iface *domain.Interface) error {
if !iface.IsDisabled() {
if err := m.quick.SetDNS(iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := m.quick.UnsetDNS(iface.Identifier); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to perform pre-save actions", "interface", iface.Identifier,
"error", "no capable controller found")
return nil
}
// update DNS settings only for client interfaces
if iface.Type == domain.InterfaceTypeClient || iface.Type == domain.InterfaceTypeAny {
if !iface.IsDisabled() {
if err := wgQuickController.SetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to update dns settings: %w", err)
}
} else {
if err := wgQuickController.UnsetDNS(ctx, iface.Identifier, iface.DnsStr, iface.DnsSearchStr); err != nil {
return fmt.Errorf("failed to clear dns settings: %w", err)
}
}
}
return nil
}
func (m Manager) handleInterfacePreSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error {
func (m Manager) handleInterfacePreSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
slog.Debug("executing pre-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute pre-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreUp); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreUp); err != nil {
return fmt.Errorf("failed to execute pre-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PreDown); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PreDown); err != nil {
return fmt.Errorf("failed to execute pre-down hook: %w", err)
}
}
return nil
}
func (m Manager) handleInterfacePostSaveHooks(iface *domain.Interface, oldEnabled, newEnabled bool) error {
func (m Manager) handleInterfacePostSaveHooks(
ctx context.Context,
iface *domain.Interface,
oldEnabled, newEnabled bool,
) error {
if oldEnabled == newEnabled {
return nil // do nothing if state did not change
}
slog.Debug("executing post-save hooks", "interface", iface.Identifier, "up", newEnabled)
wgQuickController, ok := m.wg.GetController(*iface).(WgQuickController)
if !ok {
slog.Warn("failed to execute post-save hooks", "interface", iface.Identifier, "up", newEnabled,
"error", "no capable controller found")
return nil
}
if newEnabled {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostUp); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostUp); err != nil {
return fmt.Errorf("failed to execute post-up hook: %w", err)
}
} else {
if err := m.quick.ExecuteInterfaceHook(iface.Identifier, iface.PostDown); err != nil {
if err := wgQuickController.ExecuteInterfaceHook(ctx, iface.Identifier, iface.PostDown); err != nil {
return fmt.Errorf("failed to execute post-down hook: %w", err)
}
}
@ -799,7 +852,7 @@ func (m Manager) getFreshListenPort(ctx context.Context) (port int, err error) {
func (m Manager) importInterface(
ctx context.Context,
backend InterfaceController,
backend domain.InterfaceController,
in *domain.PhysicalInterface,
peers []domain.PhysicalPeer,
) error {
@ -901,13 +954,9 @@ func (m Manager) importPeer(ctx context.Context, in *domain.Interface, p *domain
return nil
}
func (m Manager) deleteInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) error {
iface, allPeers, err := m.db.GetInterfaceAndPeers(ctx, id)
if err != nil {
return err
}
func (m Manager) deleteInterfacePeers(ctx context.Context, iface *domain.Interface, allPeers []domain.Peer) error {
for _, peer := range allPeers {
err = m.wg.GetController(*iface).DeletePeer(ctx, id, peer.Identifier)
err := m.wg.GetController(*iface).DeletePeer(ctx, iface.Identifier, peer.Identifier)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("wireguard peer deletion failure for %s: %w", peer.Identifier, err)
}

View File

@ -388,9 +388,20 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
return fmt.Errorf("failed to delete peer %s: %w", id, err)
}
peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
if err != nil {
return fmt.Errorf("failed to load peers for interface %s: %w", iface.Identifier, err)
}
m.bus.Publish(app.TopicPeerDeleted, *peer)
// Update routes after peers have changed
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: *iface,
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
// Update interface after peers have changed
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
@ -438,20 +449,26 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier)
// region helper-functions
func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
interfaces := make(map[domain.InterfaceIdentifier]struct{})
interfaces := make(map[domain.InterfaceIdentifier]domain.Interface)
for _, peer := range peers {
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
// get interface from db if it is not yet in the map
if _, ok := interfaces[peer.InterfaceIdentifier]; !ok {
iface, err := m.db.GetInterface(ctx, peer.InterfaceIdentifier)
if err != nil {
return fmt.Errorf("unable to find interface %s: %w", peer.InterfaceIdentifier, err)
}
interfaces[peer.InterfaceIdentifier] = *iface
}
iface := interfaces[peer.InterfaceIdentifier]
// Always save the peer to the backend, regardless of disabled/expired state
// The backend will handle the disabled state appropriately
err = m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
err := m.db.SavePeer(ctx, peer.Identifier, func(p *domain.Peer) (*domain.Peer, error) {
peer.CopyCalculatedAttributes(p)
err := m.wg.GetController(*iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
err := m.wg.GetController(iface).SavePeer(ctx, peer.InterfaceIdentifier, peer.Identifier,
func(pp *domain.PhysicalPeer) (*domain.PhysicalPeer, error) {
domain.MergeToPhysicalPeer(pp, peer)
return pp, nil
@ -475,13 +492,22 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
Peer: *peer,
},
})
interfaces[peer.InterfaceIdentifier] = struct{}{}
}
// Update routes after peers have changed
if len(interfaces) != 0 {
m.bus.Publish(app.TopicRouteUpdate, "peers updated")
for id, iface := range interfaces {
interfacePeers, err := m.db.GetInterfacePeers(ctx, id)
if err != nil {
return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err)
}
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
Interface: iface,
AllowedIps: iface.GetAllowedIPs(interfacePeers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
TableStr: iface.RoutingTable,
})
}
for iface := range interfaces {

View File

@ -211,6 +211,10 @@ type OpenIDConnectProvider struct {
// If LogUserInfo is set to true, the user info retrieved from the OIDC provider will be logged in trace level.
LogUserInfo bool `yaml:"log_user_info"`
// If LogSensitiveInfo is set to true, sensitive information retrieved from the OIDC provider will be logged in trace level.
// This also includes OAuth tokens! Keep this disabled in production!
LogSensitiveInfo bool `yaml:"log_sensitive_info"`
}
// OAuthProvider contains the configuration for the OAuth provider.
@ -252,6 +256,10 @@ type OAuthProvider struct {
// If LogUserInfo is set to true, the user info retrieved from the OAuth provider will be logged in trace level.
LogUserInfo bool `yaml:"log_user_info"`
// If LogSensitiveInfo is set to true, sensitive information retrieved from the OAuth provider will be logged in trace level.
// This also includes OAuth tokens! Keep this disabled in production!
LogSensitiveInfo bool `yaml:"log_sensitive_info"`
}
// WebauthnConfig contains the configuration for the WebAuthn authenticator.

View File

@ -13,6 +13,7 @@ type Backend struct {
// Local Backend-specific configuration
IgnoredLocalInterfaces []string `yaml:"ignored_local_interfaces"` // A list of interface names that should be ignored by this backend (e.g., "wg0")
LocalResolvconfPrefix string `yaml:"local_resolvconf_prefix"` // The prefix to use for interface names when passing them to resolvconf.
// External Backend-specific configuration

View File

@ -134,6 +134,9 @@ func defaultConfig() *Config {
cfg.Backend = Backend{
Default: LocalBackendName, // local backend is the default (using wgcrtl)
// Most resolconf implementations use "tun." as a prefix for interface names.
// But systemd's implementation uses no prefix, for example.
LocalResolvconfPrefix: "tun.",
}
cfg.Web = WebConfig{

View File

@ -41,4 +41,6 @@ type MailConfig struct {
From string `yaml:"from"`
// LinkOnly specifies whether emails should only contain a link to WireGuard Portal or attach the full configuration
LinkOnly bool `yaml:"link_only"`
// AllowPeerEmail specifies whether emails should be sent to peers which have no valid user account linked, but an email address is set as "user".
AllowPeerEmail bool `yaml:"allow_peer_email"`
}

View File

@ -13,6 +13,7 @@ import (
"golang.org/x/sys/unix"
"github.com/h44z/wg-portal/internal"
"github.com/h44z/wg-portal/internal/config"
)
const (
@ -132,17 +133,30 @@ func (i *Interface) GetConfigFileName() string {
return filename
}
// GetAllowedIPs returns the allowed IPs for the interface depending on the interface type and peers.
// For example, if the interface type is Server, the allowed IPs are the IPs of the peers.
// If the interface type is Client, the allowed IPs correspond to the AllowedIPsStr of the peers.
func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr {
var allowedCidrs []Cidr
for _, peer := range peers {
for _, ip := range peer.Interface.Addresses {
allowedCidrs = append(allowedCidrs, ip.HostAddr())
switch i.Type {
case InterfaceTypeServer, InterfaceTypeAny:
for _, peer := range peers {
for _, ip := range peer.Interface.Addresses {
allowedCidrs = append(allowedCidrs, ip.HostAddr())
}
if peer.ExtraAllowedIPsStr != "" {
extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
if err == nil {
allowedCidrs = append(allowedCidrs, extraIPs...)
}
}
}
if peer.ExtraAllowedIPsStr != "" {
extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr)
case InterfaceTypeClient:
for _, peer := range peers {
allowedIPs, err := CidrsFromString(peer.AllowedIPsStr.GetValue())
if err == nil {
allowedCidrs = append(allowedCidrs, extraIPs...)
allowedCidrs = append(allowedCidrs, allowedIPs...)
}
}
}
@ -159,6 +173,7 @@ func (i *Interface) ManageRoutingTable() bool {
//
// -1 if RoutingTable was set to "off" or an error occurred
func (i *Interface) GetRoutingTable() int {
routingTableStr := strings.ToLower(i.RoutingTable)
switch {
case routingTableStr == "":
@ -166,6 +181,9 @@ func (i *Interface) GetRoutingTable() int {
case routingTableStr == "off":
return -1
case strings.HasPrefix(routingTableStr, "0x"):
if i.Backend != config.LocalBackendName {
return 0 // ignore numeric routing table numbers for non-local controllers
}
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
if err != nil {
@ -178,6 +196,9 @@ func (i *Interface) GetRoutingTable() int {
}
return int(routingTable)
default:
if i.Backend != config.LocalBackendName {
return 0 // ignore numeric routing table numbers for non-local controllers
}
routingTable, err := strconv.Atoi(routingTableStr)
if err != nil {
slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err)
@ -308,12 +329,18 @@ func MergeToPhysicalInterface(pi *PhysicalInterface, i *Interface) {
}
type RoutingTableInfo struct {
FwMark uint32
Table int
Interface Interface
AllowedIps []Cidr
FwMark uint32
Table int
TableStr string // the routing table number as string (used by mikrotik, linux uses the numeric value)
IsDeleted bool // true if the interface was deleted, false otherwise
}
func (r RoutingTableInfo) String() string {
return fmt.Sprintf("%d -> %d", r.FwMark, r.Table)
v4, v6 := CidrsPerFamily(r.AllowedIps)
return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table,
len(v4), len(v6))
}
func (r RoutingTableInfo) ManagementEnabled() bool {

View File

@ -0,0 +1,27 @@
package domain
import "context"
type InterfaceController interface {
GetId() InterfaceBackend
GetInterfaces(_ context.Context) ([]PhysicalInterface, error)
GetInterface(_ context.Context, id InterfaceIdentifier) (*PhysicalInterface, error)
GetPeers(_ context.Context, deviceId InterfaceIdentifier) ([]PhysicalPeer, error)
SaveInterface(
_ context.Context,
id InterfaceIdentifier,
updateFunc func(pi *PhysicalInterface) (*PhysicalInterface, error),
) error
DeleteInterface(_ context.Context, id InterfaceIdentifier) error
SavePeer(
_ context.Context,
deviceId InterfaceIdentifier,
id PeerIdentifier,
updateFunc func(pp *PhysicalPeer) (*PhysicalPeer, error),
) error
DeletePeer(_ context.Context, deviceId InterfaceIdentifier, id PeerIdentifier) error
PingAddresses(
ctx context.Context,
addr string,
) (*PingerResult, error)
}

View File

@ -5,6 +5,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/h44z/wg-portal/internal/config"
)
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
@ -37,8 +39,9 @@ func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
assert.Equal(t, expected, iface.GetConfigFileName())
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
func TestInterface_GetAllowedIPsReturnsCorrectCidrsServerMode(t *testing.T) {
peer1 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
@ -46,16 +49,45 @@ func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
},
}
peer2 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{}
iface := &Interface{Type: InterfaceTypeServer}
expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
{Cidr: "10.20.2.2/32", Addr: "10.20.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
func TestInterface_GetAllowedIPsReturnsCorrectCidrsClientMode(t *testing.T) {
peer1 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
},
},
}
peer2 := Peer{
AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
iface := &Interface{Type: InterfaceTypeClient}
expected := []Cidr{
{Cidr: "192.168.2.2/32", Addr: "192.168.2.2", NetLength: 32},
{Cidr: "10.0.2.2/32", Addr: "10.0.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
@ -66,10 +98,22 @@ func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
iface = &Interface{RoutingTable: "off", Backend: config.LocalBackendName}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
iface = &Interface{RoutingTable: "off", Backend: "mikrotik-xxx"}
assert.False(t, iface.ManageRoutingTable())
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
}
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface := &Interface{RoutingTable: ""}
iface := &Interface{RoutingTable: "", Backend: config.LocalBackendName}
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off"
@ -81,3 +125,17 @@ func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable())
}
func TestInterface_GetRoutingTableNonLocal(t *testing.T) {
iface := &Interface{RoutingTable: "off", Backend: "something different"}
assert.Equal(t, -1, iface.GetRoutingTable())
iface.RoutingTable = "0"
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "100"
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "abc"
assert.Equal(t, 0, iface.GetRoutingTable())
}

View File

@ -26,6 +26,10 @@ func (c Cidr) IsValid() bool {
return c.Prefix().IsValid()
}
func (c Cidr) EqualPrefix(other Cidr) bool {
return c.Addr == other.Addr && c.NetLength == other.NetLength
}
func CidrFromString(str string) (Cidr, error) {
prefix, err := netip.ParsePrefix(strings.TrimSpace(str))
if err != nil {
@ -199,3 +203,26 @@ func (c Cidr) Contains(other Cidr) bool {
return subnet.Contains(otherIP)
}
// ContainsDefaultRoute returns true if the given CIDRs contain a default route.
func ContainsDefaultRoute(cidrs []Cidr) bool {
for _, allowedIP := range cidrs {
if allowedIP.Prefix().Bits() == 0 {
return true
}
}
return false
}
// CidrsPerFamily returns a slice of CIDRs, one for each family (IPv4 and IPv6).
func CidrsPerFamily(cidrs []Cidr) (ipv4, ipv6 []Cidr) {
for _, cidr := range cidrs {
if cidr.IsV4() {
ipv4 = append(ipv4, cidr)
} else {
ipv6 = append(ipv6, cidr)
}
}
return
}

View File

@ -267,6 +267,7 @@ func parseHttpResponse[T any](resp *http.Response, err error) MikrotikApiRespons
}
defer func(Body io.ReadCloser) {
_, _ = io.Copy(io.Discard, Body) // ensure to empty the body
err := Body.Close()
if err != nil {
slog.Error("failed to close response body", "error", err)