diff --git a/internal/adapters/wgcontroller/local.go b/internal/adapters/wgcontroller/local.go index 38e22ca..25a261a 100644 --- a/internal/adapters/wgcontroller/local.go +++ b/internal/adapters/wgcontroller/local.go @@ -658,7 +658,12 @@ func (c LocalController) SetRoutes( if err != nil { return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err) } - if realFwMark != fwMark { + wgDev, err := c.wg.Device(string(interfaceId)) + if err != nil { + return fmt.Errorf("failed to get wg device for %s: %w", interfaceId, err) + } + currentFwMark := wgDev.FirewallMark + if int(realFwMark) != currentFwMark { slog.Debug("updating fwmark for interface", "interface", interfaceId, "oldFwMark", fwMark, "newFwMark", realFwMark, "oldTable", table, "newTable", realTable) if err := c.updateFwMarkOnInterface(interfaceId, int(realFwMark)); err != nil { @@ -878,9 +883,26 @@ func (c LocalController) RemoveRoutes( ) error { slog.Debug("removing linux routes", "interface", interfaceId, "table", table, "fwMark", fwMark, "cidrs", oldCidrs) + wgDev, err := c.wg.Device(string(interfaceId)) + if err != nil { + slog.Debug("wg device already removed, route cleanup might be incomplete", "interface", interfaceId) + wgDev = nil + } link, err := c.nl.LinkByName(string(interfaceId)) if err != nil { - return fmt.Errorf("failed to find physical link for %s: %w", interfaceId, err) + slog.Debug("physical link already removed, route cleanup might be incomplete", "interface", interfaceId) + link = nil + } + + if wgDev != nil && fwMark == 0 { + fwMark = uint32(wgDev.FirewallMark) + } + if wgDev != nil && table == 0 { + table = wgDev.FirewallMark // use the fwMark as table, this is the default behavior + } + linkIndex := -1 + if link != nil { + linkIndex = link.Attrs().Index } cidrsV4, cidrsV6 := domain.CidrsPerFamily(oldCidrs) @@ -889,13 +911,26 @@ func (c LocalController) RemoveRoutes( return fmt.Errorf("failed to get or create routing table and fwmark for %s: %w", interfaceId, err) } - err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4) - if err != nil { - return fmt.Errorf("failed to remove v4 routes: %w", err) + if linkIndex > 0 { + err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V4, realTable, realFwMark, cidrsV4) + if err != nil { + return fmt.Errorf("failed to remove v4 routes: %w", err) + } + err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6) + if err != nil { + return fmt.Errorf("failed to remove v6 routes: %w", err) + } } - err = c.removeRoutesForFamily(interfaceId, link, netlink.FAMILY_V6, realTable, realFwMark, cidrsV6) - if err != nil { - return fmt.Errorf("failed to remove v6 routes: %w", err) + + if table > 0 { + err = c.removeRouteRulesForTable(netlink.FAMILY_V4, realTable) + if err != nil { + return fmt.Errorf("failed to remove v4 route rules for %s: %w", interfaceId, err) + } + err = c.removeRouteRulesForTable(netlink.FAMILY_V6, realTable) + if err != nil { + return fmt.Errorf("failed to remove v6 route rules for %s: %w", interfaceId, err) + } } return nil @@ -958,6 +993,25 @@ func (c LocalController) removeRoutesForFamily( return nil } +func (c LocalController) removeRouteRulesForTable( + family int, + table int, +) error { + existingRules, err := c.nl.RuleList(family) + if err != nil { + return fmt.Errorf("failed to get existing route rules for family-id %d: %w", family, err) + } + for _, existingRule := range existingRules { + if existingRule.Table == table { + err := c.nl.RuleDel(&existingRule) + if err != nil { + return fmt.Errorf("failed to delete old rule for table %d and family-id %d: %w", table, family, err) + } + } + } + return nil +} + // endregion routing-related // region statistics-related diff --git a/internal/app/wireguard/wireguard_interfaces.go b/internal/app/wireguard/wireguard_interfaces.go index e40c164..42ff6c6 100644 --- a/internal/app/wireguard/wireguard_interfaces.go +++ b/internal/app/wireguard/wireguard_interfaces.go @@ -462,12 +462,17 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("deletion not allowed: %w", err) } + m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ + Interface: *existingInterface, + AllowedIps: existingInterface.GetAllowedIPs(existingPeers), + FwMark: existingInterface.FirewallMark, + Table: existingInterface.GetRoutingTable(), + }) + now := time.Now() existingInterface.Disabled = &now // simulate a disabled interface existingInterface.DisabledReason = domain.DisabledReasonDeleted - physicalInterface, _ := m.wg.GetController(*existingInterface).GetInterface(ctx, id) - if err := m.handleInterfacePreSaveHooks(ctx, existingInterface, !existingInterface.IsDisabled(), false); err != nil { return fmt.Errorf("pre-delete hooks failed: %w", err) @@ -489,17 +494,6 @@ func (m Manager) DeleteInterface(ctx context.Context, id domain.InterfaceIdentif return fmt.Errorf("deletion failure: %w", err) } - fwMark := existingInterface.FirewallMark - if physicalInterface != nil && fwMark == 0 { - fwMark = physicalInterface.FirewallMark - } - m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ - Interface: *existingInterface, - AllowedIps: existingInterface.GetAllowedIPs(existingPeers), - FwMark: fwMark, - Table: existingInterface.GetRoutingTable(), - }) - if err := m.handleInterfacePostSaveHooks( ctx, existingInterface, @@ -577,15 +571,10 @@ func (m Manager) saveInterface(ctx context.Context, iface *domain.Interface) ( } if iface.IsDisabled() { - physicalInterface, _ := m.wg.GetController(*iface).GetInterface(ctx, iface.Identifier) - fwMark := iface.FirewallMark - if physicalInterface != nil && fwMark == 0 { - fwMark = physicalInterface.FirewallMark - } m.bus.Publish(app.TopicRouteRemove, domain.RoutingTableInfo{ Interface: *iface, AllowedIps: iface.GetAllowedIPs(peers), - FwMark: fwMark, + FwMark: iface.FirewallMark, Table: iface.GetRoutingTable(), }) } else { diff --git a/internal/app/wireguard/wireguard_peers.go b/internal/app/wireguard/wireguard_peers.go index bde31ce..2c402db 100644 --- a/internal/app/wireguard/wireguard_peers.go +++ b/internal/app/wireguard/wireguard_peers.go @@ -449,7 +449,6 @@ func (m Manager) GetUserPeerStats(ctx context.Context, id domain.UserIdentifier) func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { interfaces := make(map[domain.InterfaceIdentifier]domain.Interface) - interfacePeers := make(map[domain.InterfaceIdentifier][]domain.Peer) for _, peer := range peers { // get interface from db if it is not yet in the map @@ -462,7 +461,6 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { } iface := interfaces[peer.InterfaceIdentifier] - interfacePeers[iface.Identifier] = append(interfacePeers[iface.Identifier], *peer) // Always save the peer to the backend, regardless of disabled/expired state // The backend will handle the disabled state appropriately @@ -497,9 +495,14 @@ func (m Manager) savePeers(ctx context.Context, peers ...*domain.Peer) error { // Update routes after peers have changed for id, iface := range interfaces { + interfacePeers, err := m.db.GetInterfacePeers(ctx, id) + if err != nil { + return fmt.Errorf("failed to re-load peers for interface %s: %w", id, err) + } + m.bus.Publish(app.TopicRouteUpdate, domain.RoutingTableInfo{ Interface: iface, - AllowedIps: iface.GetAllowedIPs(interfacePeers[id]), + AllowedIps: iface.GetAllowedIPs(interfacePeers), FwMark: iface.FirewallMark, Table: iface.GetRoutingTable(), }) diff --git a/internal/domain/interface.go b/internal/domain/interface.go index 1ea4305..dd9dfa8 100644 --- a/internal/domain/interface.go +++ b/internal/domain/interface.go @@ -132,17 +132,30 @@ func (i *Interface) GetConfigFileName() string { return filename } +// GetAllowedIPs returns the allowed IPs for the interface depending on the interface type and peers. +// For example, if the interface type is Server, the allowed IPs are the IPs of the peers. +// If the interface type is Client, the allowed IPs correspond to the AllowedIPsStr of the peers. func (i *Interface) GetAllowedIPs(peers []Peer) []Cidr { var allowedCidrs []Cidr - for _, peer := range peers { - for _, ip := range peer.Interface.Addresses { - allowedCidrs = append(allowedCidrs, ip.HostAddr()) + switch i.Type { + case InterfaceTypeServer, InterfaceTypeAny: + for _, peer := range peers { + for _, ip := range peer.Interface.Addresses { + allowedCidrs = append(allowedCidrs, ip.HostAddr()) + } + if peer.ExtraAllowedIPsStr != "" { + extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr) + if err == nil { + allowedCidrs = append(allowedCidrs, extraIPs...) + } + } } - if peer.ExtraAllowedIPsStr != "" { - extraIPs, err := CidrsFromString(peer.ExtraAllowedIPsStr) + case InterfaceTypeClient: + for _, peer := range peers { + allowedIPs, err := CidrsFromString(peer.AllowedIPsStr.GetValue()) if err == nil { - allowedCidrs = append(allowedCidrs, extraIPs...) + allowedCidrs = append(allowedCidrs, allowedIPs...) } } } @@ -315,7 +328,8 @@ type RoutingTableInfo struct { } func (r RoutingTableInfo) String() string { - return fmt.Sprintf("%s: %d -> %d", r.Interface.Identifier, r.FwMark, r.Table) + v4, v6 := CidrsPerFamily(r.AllowedIps) + return fmt.Sprintf("%s: fwmark=%d; table=%d; routes_4=%d; routes_6=%d", r.Interface.Identifier, r.FwMark, r.Table, len(v4), len(v6)) } func (r RoutingTableInfo) ManagementEnabled() bool {