+
diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go
index 25a261a..d91a5f6 100644
--- a/internal/adapters/wgcontroller/local.go
+++ b/internal/adapters/wgcontroller/local.go
@@ -639,22 +639,18 @@ func (c LocalController) exec(command string, interfaceId domain.InterfaceIdenti
// region routing-related
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
-func (c LocalController) SetRoutes(
- _ context.Context,
- interfaceId domain.InterfaceIdentifier,
- table int,
- fwMark uint32,
- cidrs []domain.Cidr,
-) error {
- slog.Debug("setting linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", cidrs)
+func (c LocalController) SetRoutes(_ context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("setting linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
+ "cidrs", info.AllowedIps)
link, err := c.nl.LinkByName(string(interfaceId))
if err != nil {
return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err)
}
- cidrsV4, cidrsV6 := domain.CidrsPerFamily(cidrs)
- realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+ realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, info.Table, info.FwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
}
@@ -664,8 +660,8 @@ func (c LocalController) SetRoutes(
}
currentFwMark := wgDev.FirewallMark
if int(realFwMark) != currentFwMark {
- slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", fwMark,
- "newFwMark", realFwMark, "oldTable", table, "newTable", realTable)
+ slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", currentFwMark,
+ "newFwMark", realFwMark, "oldTable", info.Table, "newTable", realTable)
if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil {
return fmt.Errorf("failed to update fwmark for interface %s to %d: %w", interfaceId, realFwMark, err)
}
@@ -874,14 +870,10 @@ func (c LocalController) getRulePriority(existingRules []netlink.Rule) int {
}
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
-func (c LocalController) RemoveRoutes(
- _ context.Context,
- interfaceId domain.InterfaceIdentifier,
- table int,
- fwMark uint32,
- oldCidrs []domain.Cidr,
-) error {
- slog.Debug("removing linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", oldCidrs)
+func (c LocalController) RemoveRoutes(_ context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("removing linux routes", "interface", interfaceId, "table", info.Table, "fwMark", info.FwMark,
+ "cidrs", info.AllowedIps)
wgDev, err := c.wg.Device(string(interfaceId))
if err != nil {
@@ -894,10 +886,12 @@ func (c LocalController) RemoveRoutes(
link = nil
}
- if wgDev != nil && fwMark == 0 {
+ fwMark := info.FwMark
+ if wgDev != nil && info.FwMark == 0 {
fwMark = uint32(wgDev.FirewallMark)
}
- if wgDev != nil && table == 0 {
+ table := info.Table
+ if wgDev != nil && info.Table == 0 {
table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior
}
linkIndex := -1
@@ -905,7 +899,7 @@ func (c LocalController) RemoveRoutes(
linkIndex = link.Attrs().Index
}
- cidrsV4, cidrsV6 := domain.CidrsPerFamily(oldCidrs)
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
realTable, realFwMark, err := c.getOrCreateRoutingTableAndFwMark(link, table, fwMark)
if err != nil {
return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err)
@@ -980,6 +974,10 @@ func (c LocalController) removeRoutesForFamily(
rawRoute.Dst = netlinkAddr.IpNet()
}
+ if rawRoute.Table != table {
+ continue // ignore routes from other tables
+ }
+
route := domain.CidrFromIpNet(*rawRoute.Dst)
if !slices.Contains(cidrs, route) {
continue // only remove routes that were previously added
diff --git a/internal/adapters/wgcontroller/mikrotik.go b/internal/adapters/wgcontroller/mikrotik.go
index 004f2f2..e730bd7 100644
--- a/internal/adapters/wgcontroller/mikrotik.go
+++ b/internal/adapters/wgcontroller/mikrotik.go
@@ -15,6 +15,9 @@ import (
"github.com/h44z/wg-portal/internal/lowlevel"
)
+const MikrotikRouteDistance = 5
+const MikrotikDefaultRoutingTable = "main"
+
type MikrotikController struct {
coreCfg *config.Config
cfg *config.BackendMikrotik
@@ -866,24 +869,302 @@ func (c *MikrotikController) UnsetDNS(
// region routing-related
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
-func (c *MikrotikController) SetRoutes(
+func (c *MikrotikController) SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("setting mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
+
+ // Mikrotik needs some time to apply the changes.
+ // If we don't wait, the routes might get created multiple times as the dynamic routes are not yet available.
+ time.Sleep(2 * time.Second)
+
+ tableName, err := c.getOrCreateRoutingTables(ctx, info.Interface.Identifier, info.TableStr)
+ if err != nil {
+ return fmt.Errorf("failed to get or create routing table for %s: %v", interfaceId, err)
+ }
+
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+
+ err = c.setRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
+ if err != nil {
+ return fmt.Errorf("failed to set IPv4 routes for %s: %v", interfaceId, err)
+ }
+
+ err = c.setRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
+ if err != nil {
+ return fmt.Errorf("failed to set IPv6 routes for %s: %v", interfaceId, err)
+ }
+
+ return nil
+}
+
+func (c *MikrotikController) resolveRouteTableName(name string) string {
+ name = strings.TrimSpace(name)
+
+ var mikrotikTableName string
+ switch strings.ToLower(name) {
+ case "", "0":
+ mikrotikTableName = MikrotikDefaultRoutingTable
+ case MikrotikDefaultRoutingTable:
+ return fmt.Sprintf("wgportal-%s",
+ MikrotikDefaultRoutingTable) // if the Mikrotik Main table should be used, the table-name should be left empty or set to "0".
+ default:
+ mikrotikTableName = name
+ }
+
+ return mikrotikTableName
+}
+
+func (c *MikrotikController) getOrCreateRoutingTables(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
- table int,
- fwMark uint32,
+ table string,
+) (string, error) {
+ // retrieve current routing tables
+ wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "dynamic", "fib", "name",
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return "", fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
+ }
+
+ wantedTableName := c.resolveRouteTableName(table)
+
+ // check if the table already exists
+ for _, table := range wgReply.Data {
+ if table.GetString("name") == wantedTableName {
+ return wantedTableName, nil // already exists, nothing to do
+ }
+ }
+
+ // create the table if it does not exist
+ createReply := c.client.Create(ctx, "/routing/table", lowlevel.GenericJsonObject{
+ "name": wantedTableName,
+ "comment": fmt.Sprintf("Routing Table for %s", interfaceId),
+ "fib": strconv.FormatBool(true),
+ })
+ if createReply.Status != lowlevel.MikrotikApiStatusOk {
+ return "", fmt.Errorf("failed to create routing table %s: %v", wantedTableName, createReply.Error)
+ }
+
+ return wantedTableName, nil
+}
+
+func (c *MikrotikController) setRoutesForFamily(
+ ctx context.Context,
+ interfaceId domain.InterfaceIdentifier,
+ ipV6 bool,
+ table string,
cidrs []domain.Cidr,
) error {
+ apiPath := "/ip/route"
+ if ipV6 {
+ apiPath = "/ipv6/route"
+ }
+
+ // retrieve current routes
+ wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
+ "routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
+ },
+ Filters: map[string]string{
+ "gateway": string(interfaceId),
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
+ }
+
+ // first create or update the routes
+ for _, cidr := range cidrs {
+ // check if the route already exists
+ exists := false
+ for _, route := range wgReply.Data {
+ existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
+ if err != nil {
+ slog.Warn("failed to parse route destination address",
+ "cidr", route.GetString("dst-address"), "error", err)
+ continue
+ }
+ if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
+ exists = true
+ break
+ }
+ }
+ if exists {
+ continue // route already exists, nothing to do
+ }
+
+ // create the route
+ reply := c.client.Create(ctx, apiPath, lowlevel.GenericJsonObject{
+ "gateway": string(interfaceId),
+ "dst-address": cidr.String(),
+ "distance": strconv.Itoa(MikrotikRouteDistance),
+ "disabled": strconv.FormatBool(false),
+ "routing-table": table,
+ })
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to create new route %s via %s: %v", cidr.String(), interfaceId, reply.Error)
+ }
+ }
+
+ // finally, remove the routes that are not in the new list
+ for _, route := range wgReply.Data {
+ if route.GetBool("dynamic") {
+ continue // dynamic routes are not managed by the controller, nothing to do
+ }
+
+ existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
+ if err != nil {
+ slog.Warn("failed to parse route destination address",
+ "cidr", route.GetString("dst-address"), "error", err)
+ continue
+ }
+
+ valid := false
+ for _, cidr := range cidrs {
+ if existingRoute.EqualPrefix(cidr) {
+ valid = true
+ break
+ }
+ }
+ if valid {
+ continue // route is still valid, nothing to do
+ }
+
+ // remove the route
+ reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to remove outdated route %s: %v", existingRoute.String(), reply.Error)
+ }
+ }
+
return nil
}
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
-func (c *MikrotikController) RemoveRoutes(
+func (c *MikrotikController) RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error {
+ interfaceId := info.Interface.Identifier
+ slog.Debug("removing mikrotik routes", "interface", interfaceId, "table", info.TableStr, "cidrs", info.AllowedIps)
+
+ tableName := c.resolveRouteTableName(info.TableStr)
+
+ cidrsV4, cidrsV6 := domain.CidrsPerFamily(info.AllowedIps)
+
+ err := c.removeRoutesForFamily(ctx, interfaceId, false, tableName, cidrsV4)
+ if err != nil {
+ return fmt.Errorf("failed to remove IPv4 routes for %s: %v", interfaceId, err)
+ }
+
+ err = c.removeRoutesForFamily(ctx, interfaceId, true, tableName, cidrsV6)
+ if err != nil {
+ return fmt.Errorf("failed to remove IPv6 routes for %s: %v", interfaceId, err)
+ }
+
+ err = c.removeRoutingTable(ctx, tableName)
+ if err != nil {
+ return fmt.Errorf("failed to remove routing table for %s: %v", interfaceId, err)
+ }
+
+ return nil
+}
+
+func (c *MikrotikController) removeRoutesForFamily(
ctx context.Context,
interfaceId domain.InterfaceIdentifier,
- table int,
- fwMark uint32,
- oldCidrs []domain.Cidr,
+ ipV6 bool,
+ table string,
+ cidrs []domain.Cidr,
) error {
+ apiPath := "/ip/route"
+ if ipV6 {
+ apiPath = "/ipv6/route"
+ }
+
+ // retrieve current routes
+ wgReply := c.client.Query(ctx, apiPath, &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "disabled", "inactive", "distance", "dst-address", "dynamic", "gateway", "immediate-gw",
+ "routing-table", "scope", "target-scope", "client-dns", "comment", "disabled", "responder",
+ },
+ Filters: map[string]string{
+ "gateway": string(interfaceId),
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to find WireGuard IP route settings (v6=%t): %v", ipV6, wgReply.Error)
+ }
+
+ // remove the routes from the list
+ for _, route := range wgReply.Data {
+ if route.GetBool("dynamic") {
+ continue // dynamic routes are not managed by the controller, nothing to do
+ }
+
+ existingRoute, err := domain.CidrFromString(route.GetString("dst-address"))
+ if err != nil {
+ slog.Warn("failed to parse route destination address",
+ "cidr", route.GetString("dst-address"), "error", err)
+ continue
+ }
+
+ remove := false
+ for _, cidr := range cidrs {
+ if existingRoute.EqualPrefix(cidr) && route.GetString("routing-table") == table {
+ remove = true
+ break
+ }
+ }
+ if !remove {
+ continue // route is still valid, nothing to do
+ }
+
+ // remove the route
+ reply := c.client.Delete(ctx, apiPath+"/"+route.GetString(".id"))
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to remove old route %s: %v", existingRoute.String(), reply.Error)
+ }
+ }
+
+ return nil
+}
+
+func (c *MikrotikController) removeRoutingTable(
+ ctx context.Context,
+ table string,
+) error {
+ if table == MikrotikDefaultRoutingTable {
+ return nil // we cannot remove the default table
+ }
+
+ // retrieve current routing tables
+ wgReply := c.client.Query(ctx, "/routing/table", &lowlevel.MikrotikRequestOptions{
+ PropList: []string{
+ ".id", "dynamic", "fib", "name",
+ },
+ })
+ if wgReply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("unable to query routing tables: %v", wgReply.Error)
+ }
+
+ for _, existingTable := range wgReply.Data {
+ if existingTable.GetBool("dynamic") {
+ continue // dynamic tables are not managed by the controller, nothing to do
+ }
+ if existingTable.GetString("name") != table {
+ continue // not the table we want to remove
+ }
+
+ // remove the table
+ reply := c.client.Delete(ctx, "/routing/table/"+existingTable.GetString(".id"))
+ if reply.Status != lowlevel.MikrotikApiStatusOk {
+ return fmt.Errorf("failed to remove routing table %s: %v", table, reply.Error)
+ }
+ return nil
+ }
+
return nil
}
diff --git a/internal/app/route/routes.go b/internal/app/route/routes.go
index 6a202a9..62cd67e 100644
--- a/internal/app/route/routes.go
+++ b/internal/app/route/routes.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
+ "sync"
"github.com/h44z/wg-portal/internal/app"
"github.com/h44z/wg-portal/internal/config"
@@ -29,21 +30,9 @@ type EventBus interface {
type RoutesController interface {
// SetRoutes sets the routes for the given interface. If no routes are provided, the function is a no-op.
- SetRoutes(
- ctx context.Context,
- interfaceId domain.InterfaceIdentifier,
- table int,
- fwMark uint32,
- cidrs []domain.Cidr,
- ) error
+ SetRoutes(ctx context.Context, info domain.RoutingTableInfo) error
// RemoveRoutes removes the routes for the given interface. If no routes are provided, the function is a no-op.
- RemoveRoutes(
- ctx context.Context,
- interfaceId domain.InterfaceIdentifier,
- table int,
- fwMark uint32,
- oldCidrs []domain.Cidr,
- ) error
+ RemoveRoutes(ctx context.Context, info domain.RoutingTableInfo) error
}
// endregion dependencies
@@ -64,6 +53,8 @@ type Manager struct {
bus EventBus
db InterfaceAndPeerDatabaseRepo
wgController ControllerManager
+
+ mux *sync.Mutex
}
// NewRouteManager creates a new route manager instance.
@@ -79,6 +70,7 @@ func NewRouteManager(
db: db,
wgController: wgController,
+ mux: &sync.Mutex{},
}
m.connectToMessageBus()
@@ -98,6 +90,9 @@ func (m Manager) StartBackgroundJobs(_ context.Context) {
}
func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
+ m.mux.Lock() // ensure that only one route update is processed at a time
+ defer m.mux.Unlock()
+
slog.Debug("handling route update event", "info", info.String())
if !info.ManagementEnabled() {
@@ -115,6 +110,9 @@ func (m Manager) handleRouteUpdateEvent(info domain.RoutingTableInfo) {
}
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
+ m.mux.Lock() // ensure that only one route update is processed at a time
+ defer m.mux.Unlock()
+
slog.Debug("handling route remove event", "info", info.String())
if !info.ManagementEnabled() {
@@ -144,7 +142,7 @@ func (m Manager) syncRoutes(ctx context.Context, info domain.RoutingTableInfo) e
return nil
}
- err := rc.SetRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
+ err := rc.SetRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to set routes for interface %s: %w", info.Interface.Identifier, err)
}
@@ -164,7 +162,7 @@ func (m Manager) removeRoutes(ctx context.Context, info domain.RoutingTableInfo)
return nil
}
- err := rc.RemoveRoutes(ctx, info.Interface.Identifier, info.Table, info.FwMark, info.AllowedIps)
+ err := rc.RemoveRoutes(ctx, info)
if err != nil {
return fmt.Errorf("failed to remove routes for interface %s: %w", info.Interface.Identifier, err)
}
diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go
index 42ff6c6..3dbe53f 100644
--- a/internal/app/wireguard/wireguard_interfaces.go
+++ b/internal/app/wireguard/wireguard_interfaces.go
@@ -467,6 +467,8 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif
AllowedIps: existingInterface.GetAllowedIPs(existingPeers),
FwMark: existingInterface.FirewallMark,
Table: existingInterface.GetRoutingTable(),
+ TableStr: existingInterface.RoutingTable,
+ IsDeleted: true,
})
now := time.Now()
@@ -518,7 +520,11 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("interface validation failed: %w", err)
}
- oldEnabled, newEnabled := m.getInterfaceStateHistory(ctx, iface)
+ oldEnabled, newEnabled, routeTableChanged := false, !iface.IsDisabled(), false // if the interface did not exist, we assume it was not enabled
+ oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
+ if err == nil {
+ oldEnabled, newEnabled, routeTableChanged = m.getInterfaceStateHistory(oldInterface, iface)
+ }
if err := m.handleInterfacePreSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
return nil, fmt.Errorf("pre-save hooks failed: %w", err)
@@ -528,7 +534,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return nil, fmt.Errorf("pre-save actions failed: %w", err)
}
- err := m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
+ err = m.db.SaveInterface(ctx, iface.Identifier, func(i *domain.Interface) (*domain.Interface, error) {
iface.CopyCalculatedAttributes(i)
err := m.wg.GetController(*iface).SaveInterface(ctx, iface.Identifier,
@@ -576,6 +582,7 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
})
} else {
m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{
@@ -583,7 +590,19 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
})
+ // if the route table changed, ensure that the old entries are remove
+ if routeTableChanged {
+ m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{
+ Interface: *oldInterface,
+ AllowedIps: oldInterface.GetAllowedIPs(peers),
+ FwMark: oldInterface.FirewallMark,
+ Table: oldInterface.GetRoutingTable(),
+ TableStr: oldInterface.RoutingTable,
+ IsDeleted: true, // mark the old entries as deleted
+ })
+ }
}
if err := m.handleInterfacePostSaveHooks(ctx, iface, oldEnabled, newEnabled); err != nil {
@@ -622,13 +641,11 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) (
return iface, nil
}
-func (m Manager) getInterfaceStateHistory(ctx context.Context, iface *domain.Interface) (oldEnabled, newEnabled bool) {
- oldInterface, err := m.db.GetInterface(ctx, iface.Identifier)
- if err != nil {
- return false, !iface.IsDisabled() // if the interface did not exist, we assume it was not enabled
- }
-
- return !oldInterface.IsDisabled(), !iface.IsDisabled()
+func (m Manager) getInterfaceStateHistory(
+ oldInterface *domain.Interface,
+ iface *domain.Interface,
+) (oldEnabled, newEnabled, routeTableChanged bool) {
+ return !oldInterface.IsDisabled(), !iface.IsDisabled(), oldInterface.RoutingTable != iface.RoutingTable
}
func (m Manager) handleInterfacePreSaveActions(ctx context.Context, iface *domain.Interface) error {
diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go
index 2c402db..e42af28 100644
--- a/internal/app/wireguard/wireguard_peers.go
+++ b/internal/app/wireguard/wireguard_peers.go
@@ -400,6 +400,7 @@ func (m Manager) DeletePeer(ctx context.Context, id domain.PeerIdentifier) error
AllowedIps: iface.GetAllowedIPs(peers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
})
// Update interface after peers have changed
m.bus.Publish(app.TopicPeerInterfaceUpdated, peer.InterfaceIdentifier)
@@ -505,6 +506,7 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error {
AllowedIps: iface.GetAllowedIPs(interfacePeers),
FwMark: iface.FirewallMark,
Table: iface.GetRoutingTable(),
+ TableStr: iface.RoutingTable,
})
}
diff --git a/internal/domain/interface.go b/internal/domain/interface.go
index dd9dfa8..01c720c 100644
--- a/internal/domain/interface.go
+++ b/internal/domain/interface.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/sys/unix"
"github.com/h44z/wg-portal/internal"
+ "github.com/h44z/wg-portal/internal/config"
)
const (
@@ -172,6 +173,7 @@ func (i *Interface) ManageRoutingTable() bool {
//
// -1 if RoutingTable was set to "off" or an error occurred
func (i *Interface) GetRoutingTable() int {
+
routingTableStr := strings.ToLower(i.RoutingTable)
switch {
case routingTableStr == "":
@@ -179,6 +181,9 @@ func (i *Interface) GetRoutingTable() int {
case routingTableStr == "off":
return -1
case strings.HasPrefix(routingTableStr, "0x"):
+ if i.Backend != config.LocalBackendName {
+ return 0 // ignore numeric routing table numbers for non-local controllers
+ }
numberStr := strings.ReplaceAll(routingTableStr, "0x", "")
routingTable, err := strconv.ParseUint(numberStr, 16, 64)
if err != nil {
@@ -191,6 +196,9 @@ func (i *Interface) GetRoutingTable() int {
}
return int(routingTable)
default:
+ if i.Backend != config.LocalBackendName {
+ return 0 // ignore numeric routing table numbers for non-local controllers
+ }
routingTable, err := strconv.Atoi(routingTableStr)
if err != nil {
slog.Error("failed to parse routing table number", "table", routingTableStr, "error", err)
@@ -325,11 +333,14 @@ type RoutingTableInfo struct {
AllowedIps []Cidr
FwMark uint32
Table int
+ TableStr string // the routing table number as string (used by mikrotik, linux uses the numeric value)
+ IsDeleted bool // true if the interface was deleted, false otherwise
}
func (r RoutingTableInfo) String() string {
v4, v6 := CidrsPerFamily(r.AllowedIps)
- return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table, len(v4), len(v6))
+ return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table,
+ len(v4), len(v6))
}
func (r RoutingTableInfo) ManagementEnabled() bool {
diff --git a/internal/domain/interface_test.go b/internal/domain/interface_test.go
index 54aa74d..9f0ee50 100644
--- a/internal/domain/interface_test.go
+++ b/internal/domain/interface_test.go
@@ -5,6 +5,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
+
+ "github.com/h44z/wg-portal/internal/config"
)
func TestInterface_IsDisabledReturnsTrueWhenDisabled(t *testing.T) {
@@ -37,8 +39,9 @@ func TestInterface_GetConfigFileNameReturnsCorrectFileName(t *testing.T) {
assert.Equal(t, expected, iface.GetConfigFileName())
}
-func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
+func TestInterface_GetAllowedIPsReturnsCorrectCidrsServerMode(t *testing.T) {
peer1 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
@@ -46,16 +49,45 @@ func TestInterface_GetAllowedIPsReturnsCorrectCidrs(t *testing.T) {
},
}
peer2 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
+ ExtraAllowedIPsStr: "10.20.2.2/32",
Interface: PeerInterfaceConfig{
Addresses: []Cidr{
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
},
},
}
- iface := &Interface{}
+ iface := &Interface{Type: InterfaceTypeServer}
expected := []Cidr{
{Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
{Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
+ {Cidr: "10.20.2.2/32", Addr: "10.20.2.2", NetLength: 32},
+ }
+ assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
+}
+
+func TestInterface_GetAllowedIPsReturnsCorrectCidrsClientMode(t *testing.T) {
+ peer1 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "192.168.2.2/32"},
+ Interface: PeerInterfaceConfig{
+ Addresses: []Cidr{
+ {Cidr: "192.168.1.2/32", Addr: "192.168.1.2", NetLength: 32},
+ },
+ },
+ }
+ peer2 := Peer{
+ AllowedIPsStr: ConfigOption[string]{Value: "10.0.2.2/32"},
+ ExtraAllowedIPsStr: "10.20.2.2/32",
+ Interface: PeerInterfaceConfig{
+ Addresses: []Cidr{
+ {Cidr: "10.0.0.2/32", Addr: "10.0.0.2", NetLength: 32},
+ },
+ },
+ }
+ iface := &Interface{Type: InterfaceTypeClient}
+ expected := []Cidr{
+ {Cidr: "192.168.2.2/32", Addr: "192.168.2.2", NetLength: 32},
+ {Cidr: "10.0.2.2/32", Addr: "10.0.2.2", NetLength: 32},
}
assert.Equal(t, expected, iface.GetAllowedIPs([]Peer{peer1, peer2}))
}
@@ -66,10 +98,22 @@ func TestInterface_ManageRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "100"
assert.True(t, iface.ManageRoutingTable())
+
+ iface = &Interface{RoutingTable: "off", Backend: config.LocalBackendName}
+ assert.False(t, iface.ManageRoutingTable())
+
+ iface.RoutingTable = "100"
+ assert.True(t, iface.ManageRoutingTable())
+
+ iface = &Interface{RoutingTable: "off", Backend: "mikrotik-xxx"}
+ assert.False(t, iface.ManageRoutingTable())
+
+ iface.RoutingTable = "100"
+ assert.True(t, iface.ManageRoutingTable())
}
func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
- iface := &Interface{RoutingTable: ""}
+ iface := &Interface{RoutingTable: "", Backend: config.LocalBackendName}
assert.Equal(t, 0, iface.GetRoutingTable())
iface.RoutingTable = "off"
@@ -81,3 +125,17 @@ func TestInterface_GetRoutingTableReturnsCorrectValue(t *testing.T) {
iface.RoutingTable = "200"
assert.Equal(t, 200, iface.GetRoutingTable())
}
+
+func TestInterface_GetRoutingTableNonLocal(t *testing.T) {
+ iface := &Interface{RoutingTable: "off", Backend: "something different"}
+ assert.Equal(t, -1, iface.GetRoutingTable())
+
+ iface.RoutingTable = "0"
+ assert.Equal(t, 0, iface.GetRoutingTable())
+
+ iface.RoutingTable = "100"
+ assert.Equal(t, 0, iface.GetRoutingTable())
+
+ iface.RoutingTable = "abc"
+ assert.Equal(t, 0, iface.GetRoutingTable())
+}
diff --git a/internal/domain/ip.go b/internal/domain/ip.go
index a62fcb4..6081b70 100644
--- a/internal/domain/ip.go
+++ b/internal/domain/ip.go
@@ -26,6 +26,10 @@ func (c Cidr) IsValid() bool {
return c.Prefix().IsValid()
}
+func (c Cidr) EqualPrefix(other Cidr) bool {
+ return c.Addr == other.Addr && c.NetLength == other.NetLength
+}
+
func CidrFromString(str string) (Cidr, error) {
prefix, err := netip.ParsePrefix(strings.TrimSpace(str))
if err != nil {