830 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			830 lines
		
	
	
		
			21 KiB
		
	
	
	
		
			Go
		
	
	
	
| package util
 | |
| 
 | |
| import (
 | |
| 	"bufio"
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"io/fs"
 | |
| 	"math/rand"
 | |
| 	"net"
 | |
| 	"os"
 | |
| 	"path"
 | |
| 	"path/filepath"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"text/template"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/ngoduykhanh/wireguard-ui/store"
 | |
| 	"github.com/ngoduykhanh/wireguard-ui/telegram"
 | |
| 	"github.com/skip2/go-qrcode"
 | |
| 	"golang.org/x/mod/sumdb/dirhash"
 | |
| 
 | |
| 	externalip "github.com/glendc/go-external-ip"
 | |
| 	"github.com/labstack/gommon/log"
 | |
| 	"github.com/ngoduykhanh/wireguard-ui/model"
 | |
| 	"github.com/sdomino/scribble"
 | |
| )
 | |
| 
 | |
| var qrCodeSettings = model.QRCodeSettings{
 | |
| 	Enabled:    true,
 | |
| 	IncludeDNS: true,
 | |
| 	IncludeMTU: true,
 | |
| }
 | |
| 
 | |
| // BuildClientConfig to create wireguard client config string
 | |
| func BuildClientConfig(client model.Client, server model.Server, setting model.GlobalSetting) string {
 | |
| 	// Interface section
 | |
| 	clientAddress := fmt.Sprintf("Address = %s\n", strings.Join(client.AllocatedIPs, ","))
 | |
| 	clientPrivateKey := fmt.Sprintf("PrivateKey = %s\n", client.PrivateKey)
 | |
| 	clientDNS := ""
 | |
| 	if client.UseServerDNS {
 | |
| 		clientDNS = fmt.Sprintf("DNS = %s\n", strings.Join(setting.DNSServers, ","))
 | |
| 	}
 | |
| 	clientMTU := ""
 | |
| 	if setting.MTU > 0 {
 | |
| 		clientMTU = fmt.Sprintf("MTU = %d\n", setting.MTU)
 | |
| 	}
 | |
| 
 | |
| 	// Peer section
 | |
| 	peerPublicKey := fmt.Sprintf("PublicKey = %s\n", server.KeyPair.PublicKey)
 | |
| 	peerPresharedKey := ""
 | |
| 	if client.PresharedKey != "" {
 | |
| 		peerPresharedKey = fmt.Sprintf("PresharedKey = %s\n", client.PresharedKey)
 | |
| 	}
 | |
| 
 | |
| 	peerAllowedIPs := fmt.Sprintf("AllowedIPs = %s\n", strings.Join(client.AllowedIPs, ","))
 | |
| 
 | |
| 	desiredHost := setting.EndpointAddress
 | |
| 	desiredPort := server.Interface.ListenPort
 | |
| 	if strings.Contains(desiredHost, ":") {
 | |
| 		split := strings.Split(desiredHost, ":")
 | |
| 		desiredHost = split[0]
 | |
| 		if n, err := strconv.Atoi(split[1]); err == nil {
 | |
| 			desiredPort = n
 | |
| 		} else {
 | |
| 			log.Error("Endpoint appears to be incorrectly formatted: ", err)
 | |
| 		}
 | |
| 	}
 | |
| 	peerEndpoint := fmt.Sprintf("Endpoint = %s:%d\n", desiredHost, desiredPort)
 | |
| 
 | |
| 	peerPersistentKeepalive := ""
 | |
| 	if setting.PersistentKeepalive > 0 {
 | |
| 		peerPersistentKeepalive = fmt.Sprintf("PersistentKeepalive = %d\n", setting.PersistentKeepalive)
 | |
| 	}
 | |
| 
 | |
| 	// build the config as string
 | |
| 	strConfig := "[Interface]\n" +
 | |
| 		clientAddress +
 | |
| 		clientPrivateKey +
 | |
| 		clientDNS +
 | |
| 		clientMTU +
 | |
| 		"\n[Peer]\n" +
 | |
| 		peerPublicKey +
 | |
| 		peerPresharedKey +
 | |
| 		peerAllowedIPs +
 | |
| 		peerEndpoint +
 | |
| 		peerPersistentKeepalive
 | |
| 
 | |
| 	return strConfig
 | |
| }
 | |
| 
 | |
| // ClientDefaultsFromEnv to read the default values for creating a new client from the environment or use sane defaults
 | |
| func ClientDefaultsFromEnv() model.ClientDefaults {
 | |
| 	clientDefaults := model.ClientDefaults{}
 | |
| 	clientDefaults.AllowedIps = LookupEnvOrStrings(DefaultClientAllowedIpsEnvVar, []string{"0.0.0.0/0"})
 | |
| 	clientDefaults.ExtraAllowedIps = LookupEnvOrStrings(DefaultClientExtraAllowedIpsEnvVar, []string{})
 | |
| 	clientDefaults.UseServerDNS = LookupEnvOrBool(DefaultClientUseServerDNSEnvVar, true)
 | |
| 	clientDefaults.EnableAfterCreation = LookupEnvOrBool(DefaultClientEnableAfterCreationEnvVar, true)
 | |
| 
 | |
| 	return clientDefaults
 | |
| }
 | |
| 
 | |
| // ContainsCIDR to check if ipnet1 contains ipnet2
 | |
| // https://stackoverflow.com/a/40406619/6111641
 | |
| // https://go.dev/play/p/Q4J-JEN3sF
 | |
| func ContainsCIDR(ipnet1, ipnet2 *net.IPNet) bool {
 | |
| 	ones1, _ := ipnet1.Mask.Size()
 | |
| 	ones2, _ := ipnet2.Mask.Size()
 | |
| 	return ones1 <= ones2 && ipnet1.Contains(ipnet2.IP)
 | |
| }
 | |
| 
 | |
| // ValidateCIDR to validate a network CIDR
 | |
| func ValidateCIDR(cidr string) bool {
 | |
| 	_, _, err := net.ParseCIDR(cidr)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // ValidateCIDRList to validate a list of network CIDR
 | |
| func ValidateCIDRList(cidrs []string, allowEmpty bool) bool {
 | |
| 	for _, cidr := range cidrs {
 | |
| 		if allowEmpty {
 | |
| 			if len(cidr) > 0 {
 | |
| 				if ValidateCIDR(cidr) == false {
 | |
| 					return false
 | |
| 				}
 | |
| 			}
 | |
| 		} else {
 | |
| 			if ValidateCIDR(cidr) == false {
 | |
| 				return false
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // ValidateAllowedIPs to validate allowed ip addresses in CIDR format
 | |
| func ValidateAllowedIPs(cidrs []string) bool {
 | |
| 	if ValidateCIDRList(cidrs, false) == false {
 | |
| 		return false
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // ValidateExtraAllowedIPs to validate extra Allowed ip addresses, allowing empty strings
 | |
| func ValidateExtraAllowedIPs(cidrs []string) bool {
 | |
| 	if ValidateCIDRList(cidrs, true) == false {
 | |
| 		return false
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // ValidateServerAddresses to validate allowed ip addresses in CIDR format
 | |
| func ValidateServerAddresses(cidrs []string) bool {
 | |
| 	if ValidateCIDRList(cidrs, false) == false {
 | |
| 		return false
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // ValidateIPAddress to validate the IPv4 and IPv6 address
 | |
| func ValidateIPAddress(ip string) bool {
 | |
| 	if net.ParseIP(ip) == nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // ValidateIPAddressList to validate a list of IPv4 and IPv6 addresses
 | |
| func ValidateIPAddressList(ips []string) bool {
 | |
| 	for _, ip := range ips {
 | |
| 		if ValidateIPAddress(ip) == false {
 | |
| 			return false
 | |
| 		}
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| // GetInterfaceIPs to get local machine's interface ip addresses
 | |
| func GetInterfaceIPs() ([]model.Interface, error) {
 | |
| 	// get machine's interfaces
 | |
| 	ifaces, err := net.Interfaces()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	var interfaceList []model.Interface
 | |
| 
 | |
| 	// get interface's ip addresses
 | |
| 	for _, i := range ifaces {
 | |
| 		addrs, err := i.Addrs()
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		for _, addr := range addrs {
 | |
| 			var ip net.IP
 | |
| 			switch v := addr.(type) {
 | |
| 			case *net.IPNet:
 | |
| 				ip = v.IP
 | |
| 			case *net.IPAddr:
 | |
| 				ip = v.IP
 | |
| 			}
 | |
| 			if ip == nil || ip.IsLoopback() {
 | |
| 				continue
 | |
| 			}
 | |
| 			ip = ip.To4()
 | |
| 			if ip == nil {
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			iface := model.Interface{}
 | |
| 			iface.Name = i.Name
 | |
| 			iface.IPAddress = ip.String()
 | |
| 			interfaceList = append(interfaceList, iface)
 | |
| 		}
 | |
| 	}
 | |
| 	return interfaceList, err
 | |
| }
 | |
| 
 | |
| // GetPublicIP to get machine's public ip address
 | |
| func GetPublicIP() (model.Interface, error) {
 | |
| 	// set time out to 5 seconds
 | |
| 	cfg := externalip.ConsensusConfig{}
 | |
| 	cfg.Timeout = time.Second * 5
 | |
| 	consensus := externalip.NewConsensus(&cfg, nil)
 | |
| 
 | |
| 	// add trusted voters
 | |
| 	consensus.AddVoter(externalip.NewHTTPSource("https://checkip.amazonaws.com/"), 1)
 | |
| 	consensus.AddVoter(externalip.NewHTTPSource("http://whatismyip.akamai.com"), 1)
 | |
| 	consensus.AddVoter(externalip.NewHTTPSource("https://ifconfig.top"), 1)
 | |
| 
 | |
| 	publicInterface := model.Interface{}
 | |
| 	publicInterface.Name = "Public Address"
 | |
| 
 | |
| 	ip, err := consensus.ExternalIP()
 | |
| 	if err != nil {
 | |
| 		publicInterface.IPAddress = "N/A"
 | |
| 	} else {
 | |
| 		publicInterface.IPAddress = ip.String()
 | |
| 	}
 | |
| 
 | |
| 	// error handling happened above, no need to pass it through
 | |
| 	return publicInterface, nil
 | |
| }
 | |
| 
 | |
| // GetIPFromCIDR get ip from CIDR
 | |
| func GetIPFromCIDR(cidr string) (string, error) {
 | |
| 	ip, _, err := net.ParseCIDR(cidr)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return ip.String(), nil
 | |
| }
 | |
| 
 | |
| // GetAllocatedIPs to get all ip addresses allocated to clients and server
 | |
| func GetAllocatedIPs(ignoreClientID string) ([]string, error) {
 | |
| 	allocatedIPs := make([]string, 0)
 | |
| 
 | |
| 	// initialize database directory
 | |
| 	dir := "./db"
 | |
| 	db, err := scribble.New(dir, nil)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// read server information
 | |
| 	serverInterface := model.ServerInterface{}
 | |
| 	if err := db.Read("server", "interfaces", &serverInterface); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// append server's addresses to the result
 | |
| 	for _, cidr := range serverInterface.Addresses {
 | |
| 		ip, err := GetIPFromCIDR(cidr)
 | |
| 		if err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 		allocatedIPs = append(allocatedIPs, ip)
 | |
| 	}
 | |
| 
 | |
| 	// read client information
 | |
| 	records, err := db.ReadAll("clients")
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// append client's addresses to the result
 | |
| 	for _, f := range records {
 | |
| 		client := model.Client{}
 | |
| 		if err := json.Unmarshal(f, &client); err != nil {
 | |
| 			return nil, err
 | |
| 		}
 | |
| 
 | |
| 		if client.ID != ignoreClientID {
 | |
| 			for _, cidr := range client.AllocatedIPs {
 | |
| 				ip, err := GetIPFromCIDR(cidr)
 | |
| 				if err != nil {
 | |
| 					return nil, err
 | |
| 				}
 | |
| 				allocatedIPs = append(allocatedIPs, ip)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return allocatedIPs, nil
 | |
| }
 | |
| 
 | |
| // inc from https://play.golang.org/p/m8TNTtygK0
 | |
| func inc(ip net.IP) {
 | |
| 	for j := len(ip) - 1; j >= 0; j-- {
 | |
| 		ip[j]++
 | |
| 		if ip[j] > 0 {
 | |
| 			break
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // GetBroadcastIP func to get the broadcast ip address of a network
 | |
| func GetBroadcastIP(n *net.IPNet) net.IP {
 | |
| 	var broadcast net.IP
 | |
| 	if len(n.IP) == 4 {
 | |
| 		broadcast = net.ParseIP("0.0.0.0").To4()
 | |
| 	} else {
 | |
| 		broadcast = net.ParseIP("::")
 | |
| 	}
 | |
| 	for i := 0; i < len(n.IP); i++ {
 | |
| 		broadcast[i] = n.IP[i] | ^n.Mask[i]
 | |
| 	}
 | |
| 	return broadcast
 | |
| }
 | |
| 
 | |
| // GetBroadcastAndNetworkAddrsLookup get the ip address that can't be used with current server interfaces
 | |
| func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]bool {
 | |
| 	list := make(map[string]bool)
 | |
| 	for _, ifa := range interfaceAddresses {
 | |
| 		_, netAddr, err := net.ParseCIDR(ifa)
 | |
| 		if err != nil {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		broadcastAddr := GetBroadcastIP(netAddr).String()
 | |
| 		networkAddr := netAddr.IP.String()
 | |
| 		list[broadcastAddr] = true
 | |
| 		list[networkAddr] = true
 | |
| 	}
 | |
| 	return list
 | |
| }
 | |
| 
 | |
| // GetAvailableIP get the ip address that can be allocated from an CIDR
 | |
| // We need interfaceAddresses to find real broadcast and network addresses
 | |
| func GetAvailableIP(cidr string, allocatedList, interfaceAddresses []string) (string, error) {
 | |
| 	ip, netAddr, err := net.ParseCIDR(cidr)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	unavailableIPs := GetBroadcastAndNetworkAddrsLookup(interfaceAddresses)
 | |
| 
 | |
| 	for ip := ip.Mask(netAddr.Mask); netAddr.Contains(ip); inc(ip) {
 | |
| 		available := true
 | |
| 		suggestedAddr := ip.String()
 | |
| 		for _, allocatedAddr := range allocatedList {
 | |
| 			if suggestedAddr == allocatedAddr {
 | |
| 				available = false
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 		if available && !unavailableIPs[suggestedAddr] {
 | |
| 			return suggestedAddr, nil
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return "", errors.New("no more available ip address")
 | |
| }
 | |
| 
 | |
| // ValidateIPAllocation to validate the list of client's ip allocation
 | |
| // They must have a correct format and available in serverAddresses space
 | |
| func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ipAllocationList []string) (bool, error) {
 | |
| 	for _, clientCIDR := range ipAllocationList {
 | |
| 		ip, _, _ := net.ParseCIDR(clientCIDR)
 | |
| 
 | |
| 		// clientCIDR must be in CIDR format
 | |
| 		if ip == nil {
 | |
| 			return false, fmt.Errorf("invalid ip allocation input %s. Must be in CIDR format", clientCIDR)
 | |
| 		}
 | |
| 
 | |
| 		// return false immediately if the ip is already in use (in ipAllocatedList)
 | |
| 		for _, item := range ipAllocatedList {
 | |
| 			if item == ip.String() {
 | |
| 				return false, fmt.Errorf("IP %s already allocated", ip)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// even if it is not in use, we still need to check if it
 | |
| 		// belongs to a network of the server.
 | |
| 		var isValid = false
 | |
| 		for _, serverCIDR := range serverAddresses {
 | |
| 			_, serverNet, _ := net.ParseCIDR(serverCIDR)
 | |
| 			if serverNet.Contains(ip) {
 | |
| 				isValid = true
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// current ip allocation is valid, check the next one
 | |
| 		if isValid {
 | |
| 			continue
 | |
| 		} else {
 | |
| 			return false, fmt.Errorf("IP %s does not belong to any network addresses of WireGuard server", ip)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return true, nil
 | |
| }
 | |
| 
 | |
| // findSubnetRangeForIP to find first SR for IP, and cache the match
 | |
| func findSubnetRangeForIP(cidr string) (uint16, error) {
 | |
| 	ip, _, err := net.ParseCIDR(cidr)
 | |
| 	if err != nil {
 | |
| 		return 0, err
 | |
| 	}
 | |
| 
 | |
| 	if srName, ok := IPToSubnetRange[ip.String()]; ok {
 | |
| 		return srName, nil
 | |
| 	}
 | |
| 
 | |
| 	for srIndex, sr := range SubnetRangesOrder {
 | |
| 		for _, srCIDR := range SubnetRanges[sr] {
 | |
| 			if srCIDR.Contains(ip) {
 | |
| 				IPToSubnetRange[ip.String()] = uint16(srIndex)
 | |
| 				return uint16(srIndex), nil
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return 0, fmt.Errorf("subnet range not found for this IP")
 | |
| }
 | |
| 
 | |
| // FillClientSubnetRange to fill subnet ranges client belongs to, does nothing if SRs are not found
 | |
| func FillClientSubnetRange(client model.ClientData) model.ClientData {
 | |
| 	cl := *client.Client
 | |
| 	for _, ip := range cl.AllocatedIPs {
 | |
| 		sr, err := findSubnetRangeForIP(ip)
 | |
| 		if err != nil {
 | |
| 			continue
 | |
| 		}
 | |
| 		cl.SubnetRanges = append(cl.SubnetRanges, SubnetRangesOrder[sr])
 | |
| 	}
 | |
| 	return model.ClientData{
 | |
| 		Client: &cl,
 | |
| 		QRCode: client.QRCode,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // ValidateAndFixSubnetRanges to check if subnet ranges are valid for the server configuration
 | |
| // Removes all non-valid CIDRs
 | |
| func ValidateAndFixSubnetRanges(db store.IStore) error {
 | |
| 	if len(SubnetRangesOrder) == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	server, err := db.GetServer()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	var serverSubnets []*net.IPNet
 | |
| 	for _, addr := range server.Interface.Addresses {
 | |
| 		addr = strings.TrimSpace(addr)
 | |
| 		_, netAddr, err := net.ParseCIDR(addr)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		serverSubnets = append(serverSubnets, netAddr)
 | |
| 	}
 | |
| 
 | |
| 	for _, rng := range SubnetRangesOrder {
 | |
| 		cidrs := SubnetRanges[rng]
 | |
| 		if len(cidrs) > 0 {
 | |
| 			newCIDRs := make([]*net.IPNet, 0)
 | |
| 			for _, cidr := range cidrs {
 | |
| 				valid := false
 | |
| 
 | |
| 				for _, serverSubnet := range serverSubnets {
 | |
| 					if ContainsCIDR(serverSubnet, cidr) {
 | |
| 						valid = true
 | |
| 						break
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				if valid {
 | |
| 					newCIDRs = append(newCIDRs, cidr)
 | |
| 				} else {
 | |
| 					log.Warnf("[%v] CIDR is outside of all server subnets: %v. Removed.", rng, cidr)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			if len(newCIDRs) > 0 {
 | |
| 				SubnetRanges[rng] = newCIDRs
 | |
| 			} else {
 | |
| 				delete(SubnetRanges, rng)
 | |
| 				log.Warnf("[%v] No valid CIDRs in this subnet range. Removed.", rng)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // GetSubnetRangesString to get a formatted string, representing active subnet ranges
 | |
| func GetSubnetRangesString() string {
 | |
| 	if len(SubnetRangesOrder) == 0 {
 | |
| 		return ""
 | |
| 	}
 | |
| 
 | |
| 	strB := strings.Builder{}
 | |
| 
 | |
| 	for _, rng := range SubnetRangesOrder {
 | |
| 		cidrs := SubnetRanges[rng]
 | |
| 		if len(cidrs) > 0 {
 | |
| 			strB.WriteString(rng)
 | |
| 			strB.WriteString(":[")
 | |
| 			first := true
 | |
| 			for _, cidr := range cidrs {
 | |
| 				if !first {
 | |
| 					strB.WriteString(", ")
 | |
| 				}
 | |
| 				strB.WriteString(cidr.String())
 | |
| 				first = false
 | |
| 			}
 | |
| 			strB.WriteString("]  ")
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return strings.TrimSpace(strB.String())
 | |
| }
 | |
| 
 | |
| // WriteWireGuardServerConfig to write Wireguard server config. e.g. wg0.conf
 | |
| func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, clientDataList []model.ClientData, usersList []model.User, globalSettings model.GlobalSetting) error {
 | |
| 	var tmplWireguardConf string
 | |
| 
 | |
| 	// if set, read wg.conf template from WgConfTemplate
 | |
| 	if len(WgConfTemplate) > 0 {
 | |
| 		fileContentBytes, err := os.ReadFile(WgConfTemplate)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		tmplWireguardConf = string(fileContentBytes)
 | |
| 	} else {
 | |
| 		// read default wg.conf template file to string
 | |
| 		fileContent, err := StringFromEmbedFile(tmplDir, "wg.conf")
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		tmplWireguardConf = fileContent
 | |
| 	}
 | |
| 
 | |
| 	// parse the template
 | |
| 	t, err := template.New("wg_config").Parse(tmplWireguardConf)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// write config file to disk
 | |
| 	f, err := os.Create(globalSettings.ConfigFilePath)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	config := map[string]interface{}{
 | |
| 		"serverConfig":   serverConfig,
 | |
| 		"clientDataList": clientDataList,
 | |
| 		"globalSettings": globalSettings,
 | |
| 		"usersList":      usersList,
 | |
| 	}
 | |
| 
 | |
| 	err = t.Execute(f, config)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	f.Close()
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // SendRequestedConfigsToTelegram to send client all their configs. Returns failed configs list.
 | |
| func SendRequestedConfigsToTelegram(db store.IStore, userid int64) []string {
 | |
| 	failedList := make([]string, 0)
 | |
| 	TgUseridToClientIDMutex.RLock()
 | |
| 	if clids, found := TgUseridToClientID[userid]; found && len(clids) > 0 {
 | |
| 		TgUseridToClientIDMutex.RUnlock()
 | |
| 
 | |
| 		for _, clid := range clids {
 | |
| 			clientData, err := db.GetClientByID(clid, qrCodeSettings)
 | |
| 			if err != nil {
 | |
| 				// return fmt.Errorf("unable to get client")
 | |
| 				failedList = append(failedList, clid)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			// build config
 | |
| 			server, _ := db.GetServer()
 | |
| 			globalSettings, _ := db.GetGlobalSettings()
 | |
| 			config := BuildClientConfig(*clientData.Client, server, globalSettings)
 | |
| 			configData := []byte(config)
 | |
| 			var qrData []byte
 | |
| 
 | |
| 			if clientData.Client.PrivateKey != "" {
 | |
| 				qrData, err = qrcode.Encode(config, qrcode.Medium, 512)
 | |
| 				if err != nil {
 | |
| 					// return fmt.Errorf("unable to encode qr")
 | |
| 					failedList = append(failedList, clientData.Client.Name)
 | |
| 					continue
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			userid, err := strconv.ParseInt(clientData.Client.TgUserid, 10, 64)
 | |
| 			if err != nil {
 | |
| 				// return fmt.Errorf("tg usrid is unreadable")
 | |
| 				failedList = append(failedList, clientData.Client.Name)
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			err = telegram.SendConfig(userid, clientData.Client.Name, configData, qrData, true)
 | |
| 			if err != nil {
 | |
| 				failedList = append(failedList, clientData.Client.Name)
 | |
| 				continue
 | |
| 			}
 | |
| 			time.Sleep(2 * time.Second)
 | |
| 		}
 | |
| 	} else {
 | |
| 		TgUseridToClientIDMutex.RUnlock()
 | |
| 	}
 | |
| 	return failedList
 | |
| }
 | |
| 
 | |
| func LookupEnvOrString(key string, defaultVal string) string {
 | |
| 	if val, ok := os.LookupEnv(key); ok {
 | |
| 		return val
 | |
| 	}
 | |
| 	return defaultVal
 | |
| }
 | |
| 
 | |
| func LookupEnvOrBool(key string, defaultVal bool) bool {
 | |
| 	if val, ok := os.LookupEnv(key); ok {
 | |
| 		v, err := strconv.ParseBool(val)
 | |
| 		if err != nil {
 | |
| 			fmt.Fprintf(os.Stderr, "LookupEnvOrBool[%s]: %v\n", key, err)
 | |
| 		}
 | |
| 		return v
 | |
| 	}
 | |
| 	return defaultVal
 | |
| }
 | |
| 
 | |
| func LookupEnvOrInt(key string, defaultVal int) int {
 | |
| 	if val, ok := os.LookupEnv(key); ok {
 | |
| 		v, err := strconv.Atoi(val)
 | |
| 		if err != nil {
 | |
| 			fmt.Fprintf(os.Stderr, "LookupEnvOrInt[%s]: %v\n", key, err)
 | |
| 		}
 | |
| 		return v
 | |
| 	}
 | |
| 	return defaultVal
 | |
| }
 | |
| 
 | |
| func LookupEnvOrStrings(key string, defaultVal []string) []string {
 | |
| 	if val, ok := os.LookupEnv(key); ok {
 | |
| 		return strings.Split(val, ",")
 | |
| 	}
 | |
| 	return defaultVal
 | |
| }
 | |
| 
 | |
| func LookupEnvOrFile(key string, defaultVal string) string {
 | |
| 	if val, ok := os.LookupEnv(key); ok {
 | |
| 		if file, err := os.Open(val); err == nil {
 | |
| 			var content string
 | |
| 			scanner := bufio.NewScanner(file)
 | |
| 			for scanner.Scan() {
 | |
| 				content += scanner.Text()
 | |
| 			}
 | |
| 			return content
 | |
| 		}
 | |
| 	}
 | |
| 	return defaultVal
 | |
| }
 | |
| 
 | |
| func StringFromEmbedFile(embed fs.FS, filename string) (string, error) {
 | |
| 	file, err := embed.Open(filename)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	content, err := io.ReadAll(file)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return string(content), nil
 | |
| }
 | |
| 
 | |
| func ParseLogLevel(lvl string) (log.Lvl, error) {
 | |
| 	switch strings.ToLower(lvl) {
 | |
| 	case "debug":
 | |
| 		return log.DEBUG, nil
 | |
| 	case "info":
 | |
| 		return log.INFO, nil
 | |
| 	case "warn":
 | |
| 		return log.WARN, nil
 | |
| 	case "error":
 | |
| 		return log.ERROR, nil
 | |
| 	case "off":
 | |
| 		return log.OFF, nil
 | |
| 	default:
 | |
| 		return log.DEBUG, fmt.Errorf("not a valid log level: %s", lvl)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // GetCurrentHash returns current hashes
 | |
| func GetCurrentHash(db store.IStore) (string, string) {
 | |
| 	hashClients, _ := dirhash.HashDir(path.Join(db.GetPath(), "clients"), "prefix", dirhash.Hash1)
 | |
| 	files := append([]string(nil), "prefix/global_settings.json", "prefix/interfaces.json", "prefix/keypair.json")
 | |
| 
 | |
| 	osOpen := func(name string) (io.ReadCloser, error) {
 | |
| 		return os.Open(filepath.Join(path.Join(db.GetPath(), "server"), strings.TrimPrefix(name, "prefix")))
 | |
| 	}
 | |
| 	hashServer, _ := dirhash.Hash1(files, osOpen)
 | |
| 
 | |
| 	return hashClients, hashServer
 | |
| }
 | |
| 
 | |
| func HashesChanged(db store.IStore) bool {
 | |
| 	old, _ := db.GetHashes()
 | |
| 	oldClient := old.Client
 | |
| 	oldServer := old.Server
 | |
| 	newClient, newServer := GetCurrentHash(db)
 | |
| 
 | |
| 	if oldClient != newClient {
 | |
| 		//fmt.Println("Hash for client differs")
 | |
| 		return true
 | |
| 	}
 | |
| 	if oldServer != newServer {
 | |
| 		//fmt.Println("Hash for server differs")
 | |
| 		return true
 | |
| 	}
 | |
| 	return false
 | |
| }
 | |
| 
 | |
| func UpdateHashes(db store.IStore) error {
 | |
| 	var clientServerHashes model.ClientServerHashes
 | |
| 	clientServerHashes.Client, clientServerHashes.Server = GetCurrentHash(db)
 | |
| 	return db.SaveHashes(clientServerHashes)
 | |
| }
 | |
| 
 | |
| func RandomString(length int) string {
 | |
| 	var seededRand = rand.New(rand.NewSource(time.Now().UnixNano()))
 | |
| 	charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
 | |
| 	b := make([]byte, length)
 | |
| 	for i := range b {
 | |
| 		b[i] = charset[seededRand.Intn(len(charset))]
 | |
| 	}
 | |
| 	return string(b)
 | |
| }
 | |
| 
 | |
| func ManagePerms(path string) error {
 | |
| 	err := os.Chmod(path, 0600)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func AddTgToClientID(userid int64, clientID string) {
 | |
| 	TgUseridToClientIDMutex.Lock()
 | |
| 	defer TgUseridToClientIDMutex.Unlock()
 | |
| 
 | |
| 	if _, ok := TgUseridToClientID[userid]; ok && TgUseridToClientID[userid] != nil {
 | |
| 		TgUseridToClientID[userid] = append(TgUseridToClientID[userid], clientID)
 | |
| 	} else {
 | |
| 		TgUseridToClientID[userid] = []string{clientID}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func UpdateTgToClientID(userid int64, clientID string) {
 | |
| 	TgUseridToClientIDMutex.Lock()
 | |
| 	defer TgUseridToClientIDMutex.Unlock()
 | |
| 
 | |
| 	// Detach clientID from any existing userid
 | |
| 	for uid, cls := range TgUseridToClientID {
 | |
| 		if cls != nil {
 | |
| 			filtered := filterStringSlice(cls, clientID)
 | |
| 			if len(filtered) > 0 {
 | |
| 				TgUseridToClientID[uid] = filtered
 | |
| 			} else {
 | |
| 				delete(TgUseridToClientID, uid)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Attach it to the new one
 | |
| 	if _, ok := TgUseridToClientID[userid]; ok && TgUseridToClientID[userid] != nil {
 | |
| 		TgUseridToClientID[userid] = append(TgUseridToClientID[userid], clientID)
 | |
| 	} else {
 | |
| 		TgUseridToClientID[userid] = []string{clientID}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func RemoveTgToClientID(clientID string) {
 | |
| 	TgUseridToClientIDMutex.Lock()
 | |
| 	defer TgUseridToClientIDMutex.Unlock()
 | |
| 
 | |
| 	// Detach clientID from any existing userid
 | |
| 	for uid, cls := range TgUseridToClientID {
 | |
| 		if cls != nil {
 | |
| 			filtered := filterStringSlice(cls, clientID)
 | |
| 			if len(filtered) > 0 {
 | |
| 				TgUseridToClientID[uid] = filtered
 | |
| 			} else {
 | |
| 				delete(TgUseridToClientID, uid)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func filterStringSlice(s []string, excludedStr string) []string {
 | |
| 	filtered := s[:0]
 | |
| 	for _, v := range s {
 | |
| 		if v != excludedStr {
 | |
| 			filtered = append(filtered, v)
 | |
| 		}
 | |
| 	}
 | |
| 	return filtered
 | |
| }
 |