mirror of https://github.com/h44z/wg-portal.git
				
				
				
			
		
			
				
	
	
		
			504 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			504 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
package route
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"fmt"
 | 
						|
	"log/slog"
 | 
						|
 | 
						|
	"github.com/vishvananda/netlink"
 | 
						|
	"golang.org/x/sys/unix"
 | 
						|
	"golang.zx2c4.com/wireguard/wgctrl"
 | 
						|
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 | 
						|
 | 
						|
	"github.com/h44z/wg-portal/internal/app"
 | 
						|
	"github.com/h44z/wg-portal/internal/config"
 | 
						|
	"github.com/h44z/wg-portal/internal/domain"
 | 
						|
	"github.com/h44z/wg-portal/internal/lowlevel"
 | 
						|
)
 | 
						|
 | 
						|
// region dependencies
 | 
						|
 | 
						|
type InterfaceAndPeerDatabaseRepo interface {
 | 
						|
	// GetAllInterfaces returns all interfaces
 | 
						|
	GetAllInterfaces(ctx context.Context) ([]domain.Interface, error)
 | 
						|
	// GetInterfacePeers returns all peers for a given interface
 | 
						|
	GetInterfacePeers(ctx context.Context, id domain.InterfaceIdentifier) ([]domain.Peer, error)
 | 
						|
}
 | 
						|
 | 
						|
// EventBus is the message-bus dependency through which the route manager
// registers its event handlers.
type EventBus interface {
	// Subscribe registers fn as the handler for the given topic.
	Subscribe(topic string, fn any) error
}
 | 
						|
 | 
						|
// endregion dependencies
 | 
						|
 | 
						|
type routeRuleInfo struct {
 | 
						|
	ifaceId    domain.InterfaceIdentifier
 | 
						|
	fwMark     uint32
 | 
						|
	table      int
 | 
						|
	family     int
 | 
						|
	hasDefault bool
 | 
						|
}
 | 
						|
 | 
						|
// Manager is try to mimic wg-quick behaviour (https://git.zx2c4.com/wireguard-tools/tree/src/wg-quick/linux.bash)
 | 
						|
// for default routes.
 | 
						|
type Manager struct {
 | 
						|
	cfg *config.Config
 | 
						|
 | 
						|
	bus EventBus
 | 
						|
	wg  lowlevel.WireGuardClient
 | 
						|
	nl  lowlevel.NetlinkClient
 | 
						|
	db  InterfaceAndPeerDatabaseRepo
 | 
						|
}
 | 
						|
 | 
						|
// NewRouteManager creates a new route manager instance.
 | 
						|
func NewRouteManager(cfg *config.Config, bus EventBus, db InterfaceAndPeerDatabaseRepo) (*Manager, error) {
 | 
						|
	wg, err := wgctrl.New()
 | 
						|
	if err != nil {
 | 
						|
		panic("failed to init wgctrl: " + err.Error())
 | 
						|
	}
 | 
						|
 | 
						|
	nl := &lowlevel.NetlinkManager{}
 | 
						|
 | 
						|
	m := &Manager{
 | 
						|
		cfg: cfg,
 | 
						|
		bus: bus,
 | 
						|
 | 
						|
		db: db,
 | 
						|
		wg: wg,
 | 
						|
		nl: nl,
 | 
						|
	}
 | 
						|
 | 
						|
	m.connectToMessageBus()
 | 
						|
 | 
						|
	return m, nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) connectToMessageBus() {
 | 
						|
	_ = m.bus.Subscribe(app.TopicRouteUpdate, m.handleRouteUpdateEvent)
 | 
						|
	_ = m.bus.Subscribe(app.TopicRouteRemove, m.handleRouteRemoveEvent)
 | 
						|
}
 | 
						|
 | 
						|
// StartBackgroundJobs starts background jobs for the route manager.
 | 
						|
// This method is non-blocking and returns immediately.
 | 
						|
func (m Manager) StartBackgroundJobs(_ context.Context) {
 | 
						|
	// this is a no-op for now
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) handleRouteUpdateEvent(srcDescription string) {
 | 
						|
	slog.Debug("handling route update event", "source", srcDescription)
 | 
						|
 | 
						|
	err := m.syncRoutes(context.Background())
 | 
						|
	if err != nil {
 | 
						|
		slog.Error("failed to synchronize routes",
 | 
						|
			"source", srcDescription,
 | 
						|
			"error", err)
 | 
						|
	}
 | 
						|
 | 
						|
	slog.Debug("routes synchronized", "source", srcDescription)
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) handleRouteRemoveEvent(info domain.RoutingTableInfo) {
 | 
						|
	slog.Debug("handling route remove event", "info", info.String())
 | 
						|
 | 
						|
	if !info.ManagementEnabled() {
 | 
						|
		return // route management disabled
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V4); err != nil {
 | 
						|
		slog.Error("failed to remove v4 fwmark rules", "error", err)
 | 
						|
	}
 | 
						|
	if err := m.removeFwMarkRules(info.FwMark, info.GetRoutingTable(), netlink.FAMILY_V6); err != nil {
 | 
						|
		slog.Error("failed to remove v6 fwmark rules", "error", err)
 | 
						|
	}
 | 
						|
 | 
						|
	slog.Debug("routes removed", "table", info.String())
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) syncRoutes(ctx context.Context) error {
 | 
						|
	interfaces, err := m.db.GetAllInterfaces(ctx)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to find all interfaces: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	rules := map[int][]routeRuleInfo{
 | 
						|
		netlink.FAMILY_V4: nil,
 | 
						|
		netlink.FAMILY_V6: nil,
 | 
						|
	}
 | 
						|
	for _, iface := range interfaces {
 | 
						|
		if iface.IsDisabled() {
 | 
						|
			continue // disabled interface does not need route entries
 | 
						|
		}
 | 
						|
		if !iface.ManageRoutingTable() {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		peers, err := m.db.GetInterfacePeers(ctx, iface.Identifier)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to find peers for %s: %w", iface.Identifier, err)
 | 
						|
		}
 | 
						|
		allowedIPs := iface.GetAllowedIPs(peers)
 | 
						|
		defRouteV4, defRouteV6 := m.containsDefaultRoute(allowedIPs)
 | 
						|
 | 
						|
		link, err := m.nl.LinkByName(string(iface.Identifier))
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to find physical link for %s: %w", iface.Identifier, err)
 | 
						|
		}
 | 
						|
 | 
						|
		table, fwmark, err := m.getRoutingTableAndFwMark(&iface, link)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to get table and fwmark for %s: %w", iface.Identifier, err)
 | 
						|
		}
 | 
						|
 | 
						|
		if err := m.setInterfaceRoutes(link, table, allowedIPs); err != nil {
 | 
						|
			return fmt.Errorf("failed to set routes for %s: %w", iface.Identifier, err)
 | 
						|
		}
 | 
						|
 | 
						|
		if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V4, allowedIPs); err != nil {
 | 
						|
			return fmt.Errorf("failed to remove deprecated v4 routes for %s: %w", iface.Identifier, err)
 | 
						|
		}
 | 
						|
		if err := m.removeDeprecatedRoutes(link, netlink.FAMILY_V6, allowedIPs); err != nil {
 | 
						|
			return fmt.Errorf("failed to remove deprecated v6 routes for %s: %w", iface.Identifier, err)
 | 
						|
		}
 | 
						|
 | 
						|
		if table != 0 {
 | 
						|
			rules[netlink.FAMILY_V4] = append(rules[netlink.FAMILY_V4], routeRuleInfo{
 | 
						|
				ifaceId:    iface.Identifier,
 | 
						|
				fwMark:     fwmark,
 | 
						|
				table:      table,
 | 
						|
				family:     netlink.FAMILY_V4,
 | 
						|
				hasDefault: defRouteV4,
 | 
						|
			})
 | 
						|
		}
 | 
						|
		if table != 0 {
 | 
						|
			rules[netlink.FAMILY_V6] = append(rules[netlink.FAMILY_V6], routeRuleInfo{
 | 
						|
				ifaceId:    iface.Identifier,
 | 
						|
				fwMark:     fwmark,
 | 
						|
				table:      table,
 | 
						|
				family:     netlink.FAMILY_V6,
 | 
						|
				hasDefault: defRouteV6,
 | 
						|
			})
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return m.syncRouteRules(rules)
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) syncRouteRules(allRules map[int][]routeRuleInfo) error {
 | 
						|
	for family, rules := range allRules {
 | 
						|
		// update fwmark rules
 | 
						|
		if err := m.setFwMarkRules(rules, family); err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		// update main rule
 | 
						|
		if err := m.setMainRule(rules, family); err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
 | 
						|
		// cleanup old main rules
 | 
						|
		if err := m.cleanupMainRule(rules, family); err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) setFwMarkRules(rules []routeRuleInfo, family int) error {
 | 
						|
	for _, rule := range rules {
 | 
						|
		existingRules, err := m.nl.RuleList(family)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
 | 
						|
		}
 | 
						|
 | 
						|
		ruleExists := false
 | 
						|
		for _, existingRule := range existingRules {
 | 
						|
			if rule.fwMark == existingRule.Mark && rule.table == existingRule.Table {
 | 
						|
				ruleExists = true
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if ruleExists {
 | 
						|
			continue // rule already exists, no need to recreate it
 | 
						|
		}
 | 
						|
 | 
						|
		// create missing rule
 | 
						|
		if err := m.nl.RuleAdd(&netlink.Rule{
 | 
						|
			Family:            family,
 | 
						|
			Table:             rule.table,
 | 
						|
			Mark:              rule.fwMark,
 | 
						|
			Invert:            true,
 | 
						|
			SuppressIfgroup:   -1,
 | 
						|
			SuppressPrefixlen: -1,
 | 
						|
			Priority:          m.getRulePriority(existingRules),
 | 
						|
			Mask:              nil,
 | 
						|
			Goto:              -1,
 | 
						|
			Flow:              -1,
 | 
						|
		}); err != nil {
 | 
						|
			return fmt.Errorf("failed to setup rule for fwmark %d and table %d: %w", rule.fwMark, rule.table, err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) removeFwMarkRules(fwmark uint32, table int, family int) error {
 | 
						|
	existingRules, err := m.nl.RuleList(family)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
 | 
						|
	}
 | 
						|
 | 
						|
	for _, existingRule := range existingRules {
 | 
						|
		if fwmark == existingRule.Mark && table == existingRule.Table {
 | 
						|
			existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
 | 
						|
			if err := m.nl.RuleDel(&existingRule); err != nil {
 | 
						|
				return fmt.Errorf("failed to delete fwmark rule: %w", err)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) setMainRule(rules []routeRuleInfo, family int) error {
 | 
						|
	shouldHaveMainRule := false
 | 
						|
	for _, rule := range rules {
 | 
						|
		if rule.hasDefault == true {
 | 
						|
			shouldHaveMainRule = true
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if !shouldHaveMainRule {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
 | 
						|
	existingRules, err := m.nl.RuleList(family)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
 | 
						|
	}
 | 
						|
 | 
						|
	ruleExists := false
 | 
						|
	for _, existingRule := range existingRules {
 | 
						|
		if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
 | 
						|
			ruleExists = true
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if ruleExists {
 | 
						|
		return nil // rule already exists, skip re-creation
 | 
						|
	}
 | 
						|
 | 
						|
	if err := m.nl.RuleAdd(&netlink.Rule{
 | 
						|
		Family:            family,
 | 
						|
		Table:             unix.RT_TABLE_MAIN,
 | 
						|
		SuppressIfgroup:   -1,
 | 
						|
		SuppressPrefixlen: 0,
 | 
						|
		Priority:          m.getMainRulePriority(existingRules),
 | 
						|
		Mark:              0,
 | 
						|
		Mask:              nil,
 | 
						|
		Goto:              -1,
 | 
						|
		Flow:              -1,
 | 
						|
	}); err != nil {
 | 
						|
		return fmt.Errorf("failed to setup rule for main table: %w", err)
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) cleanupMainRule(rules []routeRuleInfo, family int) error {
 | 
						|
	existingRules, err := m.nl.RuleList(family)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to get existing rules for family %d: %w", family, err)
 | 
						|
	}
 | 
						|
 | 
						|
	shouldHaveMainRule := false
 | 
						|
	for _, rule := range rules {
 | 
						|
		if rule.hasDefault == true {
 | 
						|
			shouldHaveMainRule = true
 | 
						|
			break
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	mainRules := 0
 | 
						|
	for _, existingRule := range existingRules {
 | 
						|
		if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
 | 
						|
			mainRules++
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	removalCount := 0
 | 
						|
	if mainRules > 1 {
 | 
						|
		removalCount = mainRules - 1 // we only want one single rule
 | 
						|
	}
 | 
						|
	if !shouldHaveMainRule {
 | 
						|
		removalCount = mainRules
 | 
						|
	}
 | 
						|
 | 
						|
	for _, existingRule := range existingRules {
 | 
						|
		if existingRule.Table == unix.RT_TABLE_MAIN && existingRule.SuppressPrefixlen == 0 {
 | 
						|
			if removalCount > 0 {
 | 
						|
				existingRule.Family = family // set family, somehow the RuleList method does not populate the family field
 | 
						|
				if err := m.nl.RuleDel(&existingRule); err != nil {
 | 
						|
					return fmt.Errorf("failed to delete main rule: %w", err)
 | 
						|
				}
 | 
						|
				removalCount--
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) getMainRulePriority(existingRules []netlink.Rule) int {
 | 
						|
	prio := m.cfg.Advanced.RulePrioOffset
 | 
						|
	for {
 | 
						|
		isFresh := true
 | 
						|
		for _, existingRule := range existingRules {
 | 
						|
			if existingRule.Priority == prio {
 | 
						|
				isFresh = false
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if isFresh {
 | 
						|
			break
 | 
						|
		} else {
 | 
						|
			prio++
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return prio
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) getRulePriority(existingRules []netlink.Rule) int {
 | 
						|
	prio := 32700 // linux main rule has a prio of 32766
 | 
						|
	for {
 | 
						|
		isFresh := true
 | 
						|
		for _, existingRule := range existingRules {
 | 
						|
			if existingRule.Priority == prio {
 | 
						|
				isFresh = false
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if isFresh {
 | 
						|
			break
 | 
						|
		} else {
 | 
						|
			prio--
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return prio
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) setInterfaceRoutes(link netlink.Link, table int, allowedIPs []domain.Cidr) error {
 | 
						|
	for _, allowedIP := range allowedIPs {
 | 
						|
		err := m.nl.RouteReplace(&netlink.Route{
 | 
						|
			LinkIndex: link.Attrs().Index,
 | 
						|
			Dst:       allowedIP.IpNet(),
 | 
						|
			Table:     table,
 | 
						|
			Scope:     unix.RT_SCOPE_LINK,
 | 
						|
			Type:      unix.RTN_UNICAST,
 | 
						|
		})
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to add/update route %s: %w", allowedIP.String(), err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) removeDeprecatedRoutes(link netlink.Link, family int, allowedIPs []domain.Cidr) error {
 | 
						|
	rawRoutes, err := m.nl.RouteListFiltered(family, &netlink.Route{
 | 
						|
		LinkIndex: link.Attrs().Index,
 | 
						|
		Table:     unix.RT_TABLE_UNSPEC, // all tables
 | 
						|
		Scope:     unix.RT_SCOPE_LINK,
 | 
						|
		Type:      unix.RTN_UNICAST,
 | 
						|
	}, netlink.RT_FILTER_TABLE|netlink.RT_FILTER_TYPE|netlink.RT_FILTER_OIF)
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to fetch raw routes: %w", err)
 | 
						|
	}
 | 
						|
	for _, rawRoute := range rawRoutes {
 | 
						|
		if rawRoute.Dst == nil { // handle default route
 | 
						|
			var netlinkAddr domain.Cidr
 | 
						|
			if family == netlink.FAMILY_V4 {
 | 
						|
				netlinkAddr, _ = domain.CidrFromString("0.0.0.0/0")
 | 
						|
			} else {
 | 
						|
				netlinkAddr, _ = domain.CidrFromString("::/0")
 | 
						|
			}
 | 
						|
			rawRoute.Dst = netlinkAddr.IpNet()
 | 
						|
		}
 | 
						|
 | 
						|
		netlinkAddr := domain.CidrFromIpNet(*rawRoute.Dst)
 | 
						|
		remove := true
 | 
						|
		for _, allowedIP := range allowedIPs {
 | 
						|
			if netlinkAddr == allowedIP {
 | 
						|
				remove = false
 | 
						|
				break
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if !remove {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		err := m.nl.RouteDel(&rawRoute)
 | 
						|
		if err != nil {
 | 
						|
			return fmt.Errorf("failed to remove deprecated route %s: %w", netlinkAddr.String(), err)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) getRoutingTableAndFwMark(iface *domain.Interface, link netlink.Link) (
 | 
						|
	table int,
 | 
						|
	fwmark uint32,
 | 
						|
	err error,
 | 
						|
) {
 | 
						|
	table = iface.GetRoutingTable()
 | 
						|
	fwmark = iface.FirewallMark
 | 
						|
 | 
						|
	if fwmark == 0 {
 | 
						|
		// generate a new (temporary) firewall mark based on the interface index
 | 
						|
		fwmark = uint32(m.cfg.Advanced.RouteTableOffset + link.Attrs().Index)
 | 
						|
		slog.Debug("using fwmark to handle routes",
 | 
						|
			"interface", iface.Identifier,
 | 
						|
			"fwmark", fwmark)
 | 
						|
 | 
						|
		// apply the temporary fwmark to the wireguard interface
 | 
						|
		err = m.setFwMark(iface.Identifier, int(fwmark))
 | 
						|
	}
 | 
						|
	if table == 0 {
 | 
						|
		table = int(fwmark) // generate a new routing table base on interface index
 | 
						|
		slog.Debug("using routing table to handle default routes",
 | 
						|
			"interface", iface.Identifier,
 | 
						|
			"table", table)
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) setFwMark(id domain.InterfaceIdentifier, fwmark int) error {
 | 
						|
	err := m.wg.ConfigureDevice(string(id), wgtypes.Config{
 | 
						|
		FirewallMark: &fwmark,
 | 
						|
	})
 | 
						|
	if err != nil {
 | 
						|
		return fmt.Errorf("failed to update fwmark to: %d: %w", fwmark, err)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func (m Manager) containsDefaultRoute(allowedIPs []domain.Cidr) (ipV4, ipV6 bool) {
 | 
						|
	for _, allowedIP := range allowedIPs {
 | 
						|
		if ipV4 && ipV6 {
 | 
						|
			break // speed up
 | 
						|
		}
 | 
						|
 | 
						|
		if allowedIP.Prefix().Bits() == 0 {
 | 
						|
			if allowedIP.IsV4() {
 | 
						|
				ipV4 = true
 | 
						|
			} else {
 | 
						|
				ipV6 = true
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return
 | 
						|
}
 |