515 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			515 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
| // Package mysqldb provides a MySQL storage backend for Wireguard UI
 | |
| package mysqldb
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"encoding/base64"
 | |
| 	"fmt"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	rice "github.com/GeertJohan/go.rice"
 | |
| 	"github.com/go-sql-driver/mysql"
 | |
| 	"github.com/skip2/go-qrcode"
 | |
| 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 | |
| 
 | |
| 	"github.com/ngoduykhanh/wireguard-ui/model"
 | |
| 	"github.com/ngoduykhanh/wireguard-ui/util"
 | |
| )
 | |
| 
 | |
// MySQLDB is the MySQL-backed storage implementation used by Wireguard UI.
// Build one with New; the zero value has no connection and is not usable.
type MySQLDB struct {
	// conn is the connection pool opened by New.
	conn *sql.DB
	// schema is the DDL text loaded by New from mysql.sql; Init executes it
	// against an empty database.
	schema string
	// dbName is the database (schema) name; Init uses it to look up existing
	// tables in information_schema.
	dbName string
}
 | |
| 
 | |
var (
	// arrayDelimiter joins the elements of a []string field into a single
	// text column on write, and splits that column back apart on read.
	arrayDelimiter = ","
)
 | |
| 
 | |
| // New returns pointer to MySQL database
 | |
| func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) {
 | |
| 	// Set connection config
 | |
| 	config := mysql.NewConfig()
 | |
| 	config.User = uname
 | |
| 	config.Passwd = pwd
 | |
| 	config.Net = "tcp"
 | |
| 	config.Addr = fmt.Sprintf("%s:%d", host, port)
 | |
| 	config.DBName = database
 | |
| 	config.MultiStatements = true
 | |
| 	config.ParseTime = true
 | |
| 	config.TLSConfig = tls
 | |
| 
 | |
| 	// Open connection pool
 | |
| 	conn, err := sql.Open("mysql", config.FormatDSN())
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	conn.SetConnMaxLifetime(time.Minute * 3)
 | |
| 	conn.SetMaxOpenConns(10)
 | |
| 	conn.SetMaxIdleConns(10)
 | |
| 
 | |
| 	// Test the connection
 | |
| 	if err := conn.Ping(); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// Load DB schema
 | |
| 	schema, err := templateBox.String("mysql.sql")
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ans := MySQLDB{
 | |
| 		conn:   conn,
 | |
| 		schema: schema,
 | |
| 		dbName: database,
 | |
| 	}
 | |
| 	return &ans, nil
 | |
| }
 | |
| 
 | |
| // Init initializes the database
 | |
| func (o *MySQLDB) Init() error {
 | |
| 	// Check if database is empty
 | |
| 	var databaseEmpty int
 | |
| 	err := o.conn.QueryRow(
 | |
| 		"SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?",
 | |
| 		o.dbName,
 | |
| 	).Scan(&databaseEmpty)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if !(databaseEmpty > 0) {
 | |
| 		// Initialize database
 | |
| 		// Tell the user what we're doing as this could take a while
 | |
| 		fmt.Println("Initializing database")
 | |
| 
 | |
| 		// Create database schema
 | |
| 		if _, err := o.conn.Exec(o.schema); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// servers's interface
 | |
| 		if _, err := o.conn.Exec(
 | |
| 			"INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);",
 | |
| 			util.DefaultServerAddress,
 | |
| 			util.DefaultServerPort,
 | |
| 			time.Now().UTC(),
 | |
| 		); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// server's keypair
 | |
| 		key, err := wgtypes.GeneratePrivateKey()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if _, err := o.conn.Exec(
 | |
| 			"INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);",
 | |
| 			key.String(),
 | |
| 			key.PublicKey().String(),
 | |
| 			time.Now().UTC(),
 | |
| 		); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// global settings
 | |
| 		publicInterface, err := util.GetPublicIP()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		if _, err := o.conn.Exec(
 | |
| 			"INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);",
 | |
| 			publicInterface.IPAddress,
 | |
| 			util.DefaultDNS,
 | |
| 			util.DefaultMTU,
 | |
| 			util.DefaultPersistentKeepalive,
 | |
| 			util.DefaultConfigFilePath,
 | |
| 			time.Now().UTC(),
 | |
| 		); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// user info
 | |
| 		if _, err := o.conn.Exec(
 | |
| 			"INSERT INTO users (username, password) VALUES (?, ?);",
 | |
| 			util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername),
 | |
| 			util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword),
 | |
| 		); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // GetUser func to query user info from the database
 | |
| func (o *MySQLDB) GetUser() (model.User, error) {
 | |
| 	user := model.User{}
 | |
| 	row := o.conn.QueryRow("SELECT username, password FROM users;")
 | |
| 	err := row.Scan(
 | |
| 		&user.Username,
 | |
| 		&user.Password,
 | |
| 	)
 | |
| 	return user, err
 | |
| }
 | |
| 
 | |
| // GetGlobalSettings func to query global settings from the database
 | |
| func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) {
 | |
| 	settings := model.GlobalSetting{}
 | |
| 	var dnsServers string
 | |
| 
 | |
| 	row := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;")
 | |
| 	// Can't use ScanStruct here as doesn't know how to handle
 | |
| 	// dns_servers list. Instead we must populate struct it manually.
 | |
| 	err := row.Scan(
 | |
| 		&settings.EndpointAddress,
 | |
| 		&dnsServers,
 | |
| 		&settings.MTU,
 | |
| 		&settings.PersistentKeepalive,
 | |
| 		&settings.ConfigFilePath,
 | |
| 		&settings.UpdatedAt,
 | |
| 	)
 | |
| 	settings.DNSServers = strings.Split(dnsServers, arrayDelimiter)
 | |
| 	return settings, err
 | |
| }
 | |
| 
 | |
| // GetServer func to query Server setting from the database
 | |
| func (o *MySQLDB) GetServer() (model.Server, error) {
 | |
| 	server := model.Server{}
 | |
| 
 | |
| 	// Get interface
 | |
| 	serverInterface := model.ServerInterface{}
 | |
| 	var addresses string
 | |
| 
 | |
| 	row := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;")
 | |
| 	err := row.Scan(
 | |
| 		&addresses,
 | |
| 		&serverInterface.ListenPort,
 | |
| 		&serverInterface.UpdatedAt,
 | |
| 		&serverInterface.PostUp,
 | |
| 		&serverInterface.PostDown,
 | |
| 	)
 | |
| 	serverInterface.Addresses = strings.Split(addresses, arrayDelimiter)
 | |
| 	if err != nil {
 | |
| 		return server, err
 | |
| 	}
 | |
| 
 | |
| 	// Get keypair
 | |
| 	serverKeyPair := model.ServerKeypair{}
 | |
| 	if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;").
 | |
| 		Scan(
 | |
| 			&serverKeyPair.PrivateKey,
 | |
| 			&serverKeyPair.PublicKey,
 | |
| 			&serverKeyPair.UpdatedAt,
 | |
| 		); err != nil {
 | |
| 		return server, err
 | |
| 	}
 | |
| 
 | |
| 	// create Server object and return
 | |
| 	server.Interface = &serverInterface
 | |
| 	server.KeyPair = &serverKeyPair
 | |
| 	return server, nil
 | |
| }
 | |
| 
 | |
| // GetClients func to query Client settings from the database
 | |
| func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
 | |
| 	var clients []model.ClientData
 | |
| 
 | |
| 	rows, err := o.conn.Query("SELECT * FROM clients;")
 | |
| 	if err != nil {
 | |
| 		return clients, err
 | |
| 	}
 | |
| 
 | |
| 	for rows.Next() {
 | |
| 		client := model.Client{}
 | |
| 		clientData := model.ClientData{}
 | |
| 		var allocatedIPs string
 | |
| 		var allowedIPs string
 | |
| 		var extraAllowedIPs string
 | |
| 
 | |
| 		// Get client info
 | |
| 		if err := rows.Scan(
 | |
| 			&client.ID,
 | |
| 			&client.PrivateKey,
 | |
| 			&client.PublicKey,
 | |
| 			&client.PresharedKey,
 | |
| 			&client.Name,
 | |
| 			&client.Email,
 | |
| 			&allocatedIPs,
 | |
| 			&allowedIPs,
 | |
| 			&extraAllowedIPs,
 | |
| 			&client.UseServerDNS,
 | |
| 			&client.Enabled,
 | |
| 			&client.CreatedAt,
 | |
| 			&client.UpdatedAt,
 | |
| 		); err != nil {
 | |
| 			return clients, err
 | |
| 		}
 | |
| 		client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
 | |
| 		client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
 | |
| 		client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
 | |
| 
 | |
| 		// generate client qrcode image in base64
 | |
| 		if hasQRCode && client.PrivateKey != "" {
 | |
| 			server, _ := o.GetServer()
 | |
| 			globalSettings, _ := o.GetGlobalSettings()
 | |
| 
 | |
| 			png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
 | |
| 			if err == nil {
 | |
| 				clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
 | |
| 			} else {
 | |
| 				fmt.Print("Cannot generate QR code: ", err)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// create the list of clients and their qrcode data
 | |
| 		clientData.Client = &client
 | |
| 		clients = append(clients, clientData)
 | |
| 	}
 | |
| 
 | |
| 	return clients, nil
 | |
| }
 | |
| 
 | |
| // GetClientByID func to query Clients by ID from the database
 | |
| func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) {
 | |
| 	client := model.Client{}
 | |
| 	clientData := model.ClientData{}
 | |
| 	var allocatedIPs string
 | |
| 	var allowedIPs string
 | |
| 	var extraAllowedIPs string
 | |
| 
 | |
| 	// read client info
 | |
| 	if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan(
 | |
| 		&client.ID,
 | |
| 		&client.PrivateKey,
 | |
| 		&client.PublicKey,
 | |
| 		&client.PresharedKey,
 | |
| 		&client.Name,
 | |
| 		&client.Email,
 | |
| 		&allocatedIPs,
 | |
| 		&allowedIPs,
 | |
| 		&extraAllowedIPs,
 | |
| 		&client.UseServerDNS,
 | |
| 		&client.Enabled,
 | |
| 		&client.CreatedAt,
 | |
| 		&client.UpdatedAt,
 | |
| 	); err != nil {
 | |
| 		return clientData, err
 | |
| 	}
 | |
| 	client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
 | |
| 	client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
 | |
| 	client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
 | |
| 
 | |
| 	// generate client qrcode image in base64
 | |
| 	if hasQRCode && client.PrivateKey != "" {
 | |
| 		server, _ := o.GetServer()
 | |
| 		globalSettings, _ := o.GetGlobalSettings()
 | |
| 
 | |
| 		png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
 | |
| 		if err == nil {
 | |
| 			clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
 | |
| 		} else {
 | |
| 			fmt.Print("Cannot generate QR code: ", err)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	clientData.Client = &client
 | |
| 
 | |
| 	return clientData, nil
 | |
| }
 | |
| 
 | |
| // SaveClient func saves client to database
 | |
| func (o *MySQLDB) SaveClient(client model.Client) error {
 | |
| 	// If client doesn't exist, create a record, else update existing record
 | |
| 	querySet := `
 | |
| 		SET
 | |
| 			@id = ?,
 | |
| 			@private_key = ?,
 | |
| 			@public_key = ?,
 | |
| 			@preshared_key = ?,
 | |
| 			@name = ?,
 | |
| 			@email = ?,
 | |
| 			@allocated_ips = ?,
 | |
| 			@allowed_ips = ?,
 | |
| 			@extra_allowed_ips = ?,
 | |
| 			@use_server_dns = ?,
 | |
| 			@enabled = ?,
 | |
| 			@created_at = ?,
 | |
| 			@updated_at = ?;`
 | |
| 	queryInsert := `
 | |
| 		INSERT INTO clients(
 | |
| 			id,
 | |
| 			private_key,
 | |
| 			public_key,
 | |
| 			preshared_key,
 | |
| 			NAME,
 | |
| 			email,
 | |
| 			allocated_ips,
 | |
| 			allowed_ips,
 | |
| 			extra_allowed_ips,
 | |
| 			use_server_dns,
 | |
| 			enabled,
 | |
| 			created_at,
 | |
| 			updated_at
 | |
| 		)
 | |
| 		VALUES(
 | |
| 			@id,
 | |
| 			@private_key,
 | |
| 			@public_key,
 | |
| 			@preshared_key,
 | |
| 			@name,
 | |
| 			@email,
 | |
| 			@allocated_ips,
 | |
| 			@allowed_ips,
 | |
| 			@extra_allowed_ips,
 | |
| 			@use_server_dns,
 | |
| 			@enabled,
 | |
| 			@created_at,
 | |
| 			@updated_at
 | |
| 		)
 | |
| 		ON DUPLICATE KEY
 | |
| 		UPDATE
 | |
| 			id = @id,
 | |
| 			private_key = @private_key,
 | |
| 			public_key = @public_key,
 | |
| 			preshared_key = @preshared_key,
 | |
| 			NAME = @name,
 | |
| 			email = @email,
 | |
| 			allocated_ips = @allocated_ips,
 | |
| 			allowed_ips = @allowed_ips,
 | |
| 			extra_allowed_ips = @extra_allowed_ips,
 | |
| 			use_server_dns = @use_server_dns,
 | |
| 			enabled = @enabled,
 | |
| 			created_at = @created_at,
 | |
| 			updated_at = @updated_at;`
 | |
| 
 | |
| 	tx, err := o.conn.Begin()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	// set values
 | |
| 	if _, err := tx.Exec(
 | |
| 		querySet,
 | |
| 		client.ID,
 | |
| 		client.PrivateKey,
 | |
| 		client.PublicKey,
 | |
| 		client.PresharedKey,
 | |
| 		client.Name,
 | |
| 		client.Email,
 | |
| 		strings.Join(client.AllocatedIPs, arrayDelimiter),
 | |
| 		strings.Join(client.AllowedIPs, arrayDelimiter),
 | |
| 		strings.Join(client.ExtraAllowedIPs, arrayDelimiter),
 | |
| 		client.UseServerDNS,
 | |
| 		client.Enabled,
 | |
| 		client.CreatedAt,
 | |
| 		client.UpdatedAt,
 | |
| 	); err != nil {
 | |
| 		if rbErr := tx.Rollback(); rbErr != nil {
 | |
| 			return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
 | |
| 		}
 | |
| 
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// insert or update row
 | |
| 	if _, err := tx.Exec(queryInsert); err != nil {
 | |
| 		if rbErr := tx.Rollback(); rbErr != nil {
 | |
| 			return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
 | |
| 		}
 | |
| 
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return tx.Commit()
 | |
| }
 | |
| 
 | |
| // DeleteClient func deletes client from the database
 | |
| func (o *MySQLDB) DeleteClient(clientID string) error {
 | |
| 	if _, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // SaveServerInterface func saves a server interface to database
 | |
| func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error {
 | |
| 	// No need for ON DUPLICATE KEY UPDATE as only ever 1 record
 | |
| 	query := `
 | |
| 		UPDATE
 | |
| 			interfaces
 | |
| 		SET
 | |
| 			addresses = ?,
 | |
| 			listen_port = ?,
 | |
| 			updated_at = ?,
 | |
| 			post_up = ?,
 | |
| 			post_down = ?
 | |
| 		WHERE
 | |
| 			id = 1;`
 | |
| 
 | |
| 	_, err := o.conn.Exec(
 | |
| 		query,
 | |
| 		strings.Join(serverInterface.Addresses, arrayDelimiter),
 | |
| 		serverInterface.ListenPort,
 | |
| 		serverInterface.UpdatedAt,
 | |
| 		serverInterface.PostUp,
 | |
| 		serverInterface.PostDown,
 | |
| 	)
 | |
| 
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // SaveServerKeyPair func saves a server keypair to database
 | |
| func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
 | |
| 	query := `
 | |
| 		UPDATE
 | |
| 			keypair
 | |
| 		SET
 | |
| 			private_key = ?,
 | |
| 			public_key = ?,
 | |
| 			updated_at = ?
 | |
| 		WHERE
 | |
| 			id = 1;`
 | |
| 
 | |
| 	_, err := o.conn.Exec(
 | |
| 		query,
 | |
| 		serverKeyPair.PrivateKey,
 | |
| 		serverKeyPair.PublicKey,
 | |
| 		serverKeyPair.UpdatedAt,
 | |
| 	)
 | |
| 
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| // SaveGlobalSettings saves global settings to database
 | |
| func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
 | |
| 	query := `
 | |
| 		UPDATE
 | |
| 			global_settings
 | |
| 		SET
 | |
| 			endpoint_address = ?,
 | |
| 			dns_servers = ?,
 | |
| 			mtu = ?,
 | |
| 			persistent_keepalive = ?,
 | |
| 			config_file_path = ?,
 | |
| 			updated_at = ?
 | |
| 		WHERE
 | |
| 			id = 1;`
 | |
| 
 | |
| 	_, err := o.conn.Exec(
 | |
| 		query,
 | |
| 		globalSettings.EndpointAddress,
 | |
| 		strings.Join(globalSettings.DNSServers, arrayDelimiter),
 | |
| 		globalSettings.MTU,
 | |
| 		globalSettings.PersistentKeepalive,
 | |
| 		globalSettings.ConfigFilePath,
 | |
| 		globalSettings.UpdatedAt,
 | |
| 	)
 | |
| 
 | |
| 	return err
 | |
| }
 |