515 lines
12 KiB
Go
515 lines
12 KiB
Go
// Package mysqldb provides a MySQL storage backend for WireGuard UI.
|
|
package mysqldb
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
rice "github.com/GeertJohan/go.rice"
|
|
"github.com/go-sql-driver/mysql"
|
|
"github.com/skip2/go-qrcode"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"github.com/ngoduykhanh/wireguard-ui/model"
|
|
"github.com/ngoduykhanh/wireguard-ui/util"
|
|
)
|
|
|
|
// MySQLDB is the MySQL storage backend. Build one with New; the zero value
// has no connection pool and cannot be used.
type MySQLDB struct {
	// conn is the connection pool opened by New.
	conn *sql.DB

	// schema is the DDL script (mysql.sql) that Init executes on an empty
	// database; dbName is the database Init inspects for existing tables.
	schema, dbName string
}
|
|
|
|
var (
	// arrayDelimiter separates the items of list-valued fields (addresses,
	// DNS servers, allowed IPs) that are stored as a single text column:
	// lists are joined with it on save and split on it on load.
	arrayDelimiter = ","
)
|
|
|
|
// New returns pointer to MySQL database
|
|
func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) {
|
|
// Set connection config
|
|
config := mysql.NewConfig()
|
|
config.User = uname
|
|
config.Passwd = pwd
|
|
config.Net = "tcp"
|
|
config.Addr = fmt.Sprintf("%s:%d", host, port)
|
|
config.DBName = database
|
|
config.MultiStatements = true
|
|
config.ParseTime = true
|
|
config.TLSConfig = tls
|
|
|
|
// Open connection pool
|
|
conn, err := sql.Open("mysql", config.FormatDSN())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn.SetConnMaxLifetime(time.Minute * 3)
|
|
conn.SetMaxOpenConns(10)
|
|
conn.SetMaxIdleConns(10)
|
|
|
|
// Test the connection
|
|
if err := conn.Ping(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Load DB schema
|
|
schema, err := templateBox.String("mysql.sql")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
ans := MySQLDB{
|
|
conn: conn,
|
|
schema: schema,
|
|
dbName: database,
|
|
}
|
|
return &ans, nil
|
|
}
|
|
|
|
// Init initializes the database
|
|
func (o *MySQLDB) Init() error {
|
|
// Check if database is empty
|
|
var databaseEmpty int
|
|
err := o.conn.QueryRow(
|
|
"SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?",
|
|
o.dbName,
|
|
).Scan(&databaseEmpty)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !(databaseEmpty > 0) {
|
|
// Initialize database
|
|
// Tell the user what we're doing as this could take a while
|
|
fmt.Println("Initializing database")
|
|
|
|
// Create database schema
|
|
if _, err := o.conn.Exec(o.schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// servers's interface
|
|
if _, err := o.conn.Exec(
|
|
"INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);",
|
|
util.DefaultServerAddress,
|
|
util.DefaultServerPort,
|
|
time.Now().UTC(),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
// server's keypair
|
|
key, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := o.conn.Exec(
|
|
"INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);",
|
|
key.String(),
|
|
key.PublicKey().String(),
|
|
time.Now().UTC(),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
// global settings
|
|
publicInterface, err := util.GetPublicIP()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err := o.conn.Exec(
|
|
"INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);",
|
|
publicInterface.IPAddress,
|
|
util.DefaultDNS,
|
|
util.DefaultMTU,
|
|
util.DefaultPersistentKeepalive,
|
|
util.DefaultConfigFilePath,
|
|
time.Now().UTC(),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
|
|
// user info
|
|
if _, err := o.conn.Exec(
|
|
"INSERT INTO users (username, password) VALUES (?, ?);",
|
|
util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername),
|
|
util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword),
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetUser func to query user info from the database
|
|
func (o *MySQLDB) GetUser() (model.User, error) {
|
|
user := model.User{}
|
|
row := o.conn.QueryRow("SELECT username, password FROM users;")
|
|
err := row.Scan(
|
|
&user.Username,
|
|
&user.Password,
|
|
)
|
|
return user, err
|
|
}
|
|
|
|
// GetGlobalSettings func to query global settings from the database
|
|
func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) {
|
|
settings := model.GlobalSetting{}
|
|
var dnsServers string
|
|
|
|
row := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;")
|
|
// Can't use ScanStruct here as doesn't know how to handle
|
|
// dns_servers list. Instead we must populate struct it manually.
|
|
err := row.Scan(
|
|
&settings.EndpointAddress,
|
|
&dnsServers,
|
|
&settings.MTU,
|
|
&settings.PersistentKeepalive,
|
|
&settings.ConfigFilePath,
|
|
&settings.UpdatedAt,
|
|
)
|
|
settings.DNSServers = strings.Split(dnsServers, arrayDelimiter)
|
|
return settings, err
|
|
}
|
|
|
|
// GetServer func to query Server setting from the database
|
|
func (o *MySQLDB) GetServer() (model.Server, error) {
|
|
server := model.Server{}
|
|
|
|
// Get interface
|
|
serverInterface := model.ServerInterface{}
|
|
var addresses string
|
|
|
|
row := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;")
|
|
err := row.Scan(
|
|
&addresses,
|
|
&serverInterface.ListenPort,
|
|
&serverInterface.UpdatedAt,
|
|
&serverInterface.PostUp,
|
|
&serverInterface.PostDown,
|
|
)
|
|
serverInterface.Addresses = strings.Split(addresses, arrayDelimiter)
|
|
if err != nil {
|
|
return server, err
|
|
}
|
|
|
|
// Get keypair
|
|
serverKeyPair := model.ServerKeypair{}
|
|
if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;").
|
|
Scan(
|
|
&serverKeyPair.PrivateKey,
|
|
&serverKeyPair.PublicKey,
|
|
&serverKeyPair.UpdatedAt,
|
|
); err != nil {
|
|
return server, err
|
|
}
|
|
|
|
// create Server object and return
|
|
server.Interface = &serverInterface
|
|
server.KeyPair = &serverKeyPair
|
|
return server, nil
|
|
}
|
|
|
|
// GetClients func to query Client settings from the database
|
|
func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
|
|
var clients []model.ClientData
|
|
|
|
rows, err := o.conn.Query("SELECT * FROM clients;")
|
|
if err != nil {
|
|
return clients, err
|
|
}
|
|
|
|
for rows.Next() {
|
|
client := model.Client{}
|
|
clientData := model.ClientData{}
|
|
var allocatedIPs string
|
|
var allowedIPs string
|
|
var extraAllowedIPs string
|
|
|
|
// Get client info
|
|
if err := rows.Scan(
|
|
&client.ID,
|
|
&client.PrivateKey,
|
|
&client.PublicKey,
|
|
&client.PresharedKey,
|
|
&client.Name,
|
|
&client.Email,
|
|
&allocatedIPs,
|
|
&allowedIPs,
|
|
&extraAllowedIPs,
|
|
&client.UseServerDNS,
|
|
&client.Enabled,
|
|
&client.CreatedAt,
|
|
&client.UpdatedAt,
|
|
); err != nil {
|
|
return clients, err
|
|
}
|
|
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
|
|
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
|
|
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
|
|
|
|
// generate client qrcode image in base64
|
|
if hasQRCode && client.PrivateKey != "" {
|
|
server, _ := o.GetServer()
|
|
globalSettings, _ := o.GetGlobalSettings()
|
|
|
|
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
|
|
if err == nil {
|
|
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
|
|
} else {
|
|
fmt.Print("Cannot generate QR code: ", err)
|
|
}
|
|
}
|
|
|
|
// create the list of clients and their qrcode data
|
|
clientData.Client = &client
|
|
clients = append(clients, clientData)
|
|
}
|
|
|
|
return clients, nil
|
|
}
|
|
|
|
// GetClientByID func to query Clients by ID from the database
|
|
func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) {
|
|
client := model.Client{}
|
|
clientData := model.ClientData{}
|
|
var allocatedIPs string
|
|
var allowedIPs string
|
|
var extraAllowedIPs string
|
|
|
|
// read client info
|
|
if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan(
|
|
&client.ID,
|
|
&client.PrivateKey,
|
|
&client.PublicKey,
|
|
&client.PresharedKey,
|
|
&client.Name,
|
|
&client.Email,
|
|
&allocatedIPs,
|
|
&allowedIPs,
|
|
&extraAllowedIPs,
|
|
&client.UseServerDNS,
|
|
&client.Enabled,
|
|
&client.CreatedAt,
|
|
&client.UpdatedAt,
|
|
); err != nil {
|
|
return clientData, err
|
|
}
|
|
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
|
|
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
|
|
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
|
|
|
|
// generate client qrcode image in base64
|
|
if hasQRCode && client.PrivateKey != "" {
|
|
server, _ := o.GetServer()
|
|
globalSettings, _ := o.GetGlobalSettings()
|
|
|
|
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
|
|
if err == nil {
|
|
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
|
|
} else {
|
|
fmt.Print("Cannot generate QR code: ", err)
|
|
}
|
|
}
|
|
|
|
clientData.Client = &client
|
|
|
|
return clientData, nil
|
|
}
|
|
|
|
// SaveClient func saves client to database
|
|
func (o *MySQLDB) SaveClient(client model.Client) error {
|
|
// If client doesn't exist, create a record, else update existing record
|
|
querySet := `
|
|
SET
|
|
@id = ?,
|
|
@private_key = ?,
|
|
@public_key = ?,
|
|
@preshared_key = ?,
|
|
@name = ?,
|
|
@email = ?,
|
|
@allocated_ips = ?,
|
|
@allowed_ips = ?,
|
|
@extra_allowed_ips = ?,
|
|
@use_server_dns = ?,
|
|
@enabled = ?,
|
|
@created_at = ?,
|
|
@updated_at = ?;`
|
|
queryInsert := `
|
|
INSERT INTO clients(
|
|
id,
|
|
private_key,
|
|
public_key,
|
|
preshared_key,
|
|
NAME,
|
|
email,
|
|
allocated_ips,
|
|
allowed_ips,
|
|
extra_allowed_ips,
|
|
use_server_dns,
|
|
enabled,
|
|
created_at,
|
|
updated_at
|
|
)
|
|
VALUES(
|
|
@id,
|
|
@private_key,
|
|
@public_key,
|
|
@preshared_key,
|
|
@name,
|
|
@email,
|
|
@allocated_ips,
|
|
@allowed_ips,
|
|
@extra_allowed_ips,
|
|
@use_server_dns,
|
|
@enabled,
|
|
@created_at,
|
|
@updated_at
|
|
)
|
|
ON DUPLICATE KEY
|
|
UPDATE
|
|
id = @id,
|
|
private_key = @private_key,
|
|
public_key = @public_key,
|
|
preshared_key = @preshared_key,
|
|
NAME = @name,
|
|
email = @email,
|
|
allocated_ips = @allocated_ips,
|
|
allowed_ips = @allowed_ips,
|
|
extra_allowed_ips = @extra_allowed_ips,
|
|
use_server_dns = @use_server_dns,
|
|
enabled = @enabled,
|
|
created_at = @created_at,
|
|
updated_at = @updated_at;`
|
|
|
|
tx, err := o.conn.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// set values
|
|
if _, err := tx.Exec(
|
|
querySet,
|
|
client.ID,
|
|
client.PrivateKey,
|
|
client.PublicKey,
|
|
client.PresharedKey,
|
|
client.Name,
|
|
client.Email,
|
|
strings.Join(client.AllocatedIPs, arrayDelimiter),
|
|
strings.Join(client.AllowedIPs, arrayDelimiter),
|
|
strings.Join(client.ExtraAllowedIPs, arrayDelimiter),
|
|
client.UseServerDNS,
|
|
client.Enabled,
|
|
client.CreatedAt,
|
|
client.UpdatedAt,
|
|
); err != nil {
|
|
if rbErr := tx.Rollback(); rbErr != nil {
|
|
return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// insert or update row
|
|
if _, err := tx.Exec(queryInsert); err != nil {
|
|
if rbErr := tx.Rollback(); rbErr != nil {
|
|
return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// DeleteClient func deletes client from the database
|
|
func (o *MySQLDB) DeleteClient(clientID string) error {
|
|
if _, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SaveServerInterface func saves a server interface to database
|
|
func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error {
|
|
// No need for ON DUPLICATE KEY UPDATE as only ever 1 record
|
|
query := `
|
|
UPDATE
|
|
interfaces
|
|
SET
|
|
addresses = ?,
|
|
listen_port = ?,
|
|
updated_at = ?,
|
|
post_up = ?,
|
|
post_down = ?
|
|
WHERE
|
|
id = 1;`
|
|
|
|
_, err := o.conn.Exec(
|
|
query,
|
|
strings.Join(serverInterface.Addresses, arrayDelimiter),
|
|
serverInterface.ListenPort,
|
|
serverInterface.UpdatedAt,
|
|
serverInterface.PostUp,
|
|
serverInterface.PostDown,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// SaveServerKeyPair func saves a server keypair to database
|
|
func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
|
|
query := `
|
|
UPDATE
|
|
keypair
|
|
SET
|
|
private_key = ?,
|
|
public_key = ?,
|
|
updated_at = ?
|
|
WHERE
|
|
id = 1;`
|
|
|
|
_, err := o.conn.Exec(
|
|
query,
|
|
serverKeyPair.PrivateKey,
|
|
serverKeyPair.PublicKey,
|
|
serverKeyPair.UpdatedAt,
|
|
)
|
|
|
|
return err
|
|
}
|
|
|
|
// SaveGlobalSettings saves global settings to database
|
|
func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
|
|
query := `
|
|
UPDATE
|
|
global_settings
|
|
SET
|
|
endpoint_address = ?,
|
|
dns_servers = ?,
|
|
mtu = ?,
|
|
persistent_keepalive = ?,
|
|
config_file_path = ?,
|
|
updated_at = ?
|
|
WHERE
|
|
id = 1;`
|
|
|
|
_, err := o.conn.Exec(
|
|
query,
|
|
globalSettings.EndpointAddress,
|
|
strings.Join(globalSettings.DNSServers, arrayDelimiter),
|
|
globalSettings.MTU,
|
|
globalSettings.PersistentKeepalive,
|
|
globalSettings.ConfigFilePath,
|
|
globalSettings.UpdatedAt,
|
|
)
|
|
|
|
return err
|
|
}
|