// wireguard-ui/store/mysqldb/mysqldb.go
//
// (Source-viewer metadata from the original paste: 515 lines, 12 KiB, Go.)

// Package mysqldb provides a MySQL storage backend for Wireguard UI
package mysqldb
import (
	"database/sql"
	"encoding/base64"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	rice "github.com/GeertJohan/go.rice"
	"github.com/go-sql-driver/mysql"
	"github.com/skip2/go-qrcode"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

	"github.com/ngoduykhanh/wireguard-ui/model"
	"github.com/ngoduykhanh/wireguard-ui/util"
)
// MySQLDB is the MySQL-backed store for Wireguard UI. It bundles the
// connection pool with the schema script and database name that Init uses.
type MySQLDB struct {
	conn   *sql.DB // connection pool opened by New
	schema string  // schema script run by Init on an empty database
	dbName string  // database whose tables Init counts
}

// arrayDelimiter separates the items of a list-valued field when the list is
// stored in a single text column.
var (
	arrayDelimiter = ","
)
// New returns pointer to MySQL database.
//
// It opens a TCP connection pool to host:port for database, verifies it with
// Ping, and loads the "mysql.sql" schema from templateBox for later use by
// Init. tls is passed through as the driver's TLSConfig name. If any step
// after opening the pool fails, the pool is closed before the error is
// returned.
func New(uname string, pwd string, host string, port int, database string, tls string, templateBox *rice.Box) (*MySQLDB, error) {
	// Set connection config
	config := mysql.NewConfig()
	config.User = uname
	config.Passwd = pwd
	config.Net = "tcp"
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:3306"); a plain
	// "%s:%d" format would produce an unparseable address for them.
	config.Addr = net.JoinHostPort(host, strconv.Itoa(port))
	config.DBName = database
	// Init executes the whole schema script in a single Exec call.
	config.MultiStatements = true
	config.ParseTime = true
	config.TLSConfig = tls

	// Open connection pool
	conn, err := sql.Open("mysql", config.FormatDSN())
	if err != nil {
		return nil, err
	}
	conn.SetConnMaxLifetime(3 * time.Minute)
	conn.SetMaxOpenConns(10)
	conn.SetMaxIdleConns(10)

	// Test the connection; release the pool if it is unusable.
	if err := conn.Ping(); err != nil {
		conn.Close() // best effort: the Ping error is the one worth reporting
		return nil, err
	}

	// Load DB schema
	schema, err := templateBox.String("mysql.sql")
	if err != nil {
		conn.Close() // best effort: the schema error is the one worth reporting
		return nil, err
	}

	return &MySQLDB{
		conn:   conn,
		schema: schema,
		dbName: database,
	}, nil
}
// Init initializes the database.
//
// When the configured database contains no tables, Init creates the schema
// loaded by New and seeds the interfaces, keypair, global_settings and users
// tables with defaults. A database that already has at least one table is left
// untouched.
func (o *MySQLDB) Init() error {
	// Count the tables that already exist in the target database.
	var tableCount int
	if err := o.conn.QueryRow(
		"SELECT COUNT(DISTINCT `table_name`) FROM `information_schema`.`columns` WHERE `table_schema` = ?",
		o.dbName,
	).Scan(&tableCount); err != nil {
		return err
	}
	if tableCount > 0 {
		return nil
	}

	// Tell the user what we're doing as this could take a while
	fmt.Println("Initializing database")

	// Do the fallible non-database work before touching the schema. Once the
	// tables exist, the count above is non-zero on every later start and
	// seeding is never retried, so failing after schema creation would leave
	// the tables empty.
	key, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return err
	}
	publicInterface, err := util.GetPublicIP()
	if err != nil {
		return err
	}

	// Create database schema
	if _, err := o.conn.Exec(o.schema); err != nil {
		return err
	}

	// Seed all default rows in one transaction so they are committed together
	// or not at all.
	// NOTE(review): a rolled-back seed still leaves the (empty) schema behind,
	// which the table count above treats as initialized.
	now := time.Now().UTC()
	seeds := []struct {
		query string
		args  []interface{}
	}{
		{ // server's interface
			query: "INSERT INTO interfaces (addresses, listen_port, updated_at) VALUES (?, ?, ?);",
			args:  []interface{}{util.DefaultServerAddress, util.DefaultServerPort, now},
		},
		{ // server's keypair
			query: "INSERT INTO keypair (private_key, public_key, updated_at) VALUES (?, ?, ?);",
			args:  []interface{}{key.String(), key.PublicKey().String(), now},
		},
		{ // global settings
			query: "INSERT INTO global_settings (endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at) VALUES (?, ?, ?, ?, ?, ?);",
			args: []interface{}{
				publicInterface.IPAddress,
				util.DefaultDNS,
				util.DefaultMTU,
				util.DefaultPersistentKeepalive,
				util.DefaultConfigFilePath,
				now,
			},
		},
		{ // user info
			query: "INSERT INTO users (username, password) VALUES (?, ?);",
			args: []interface{}{
				util.GetCredVar(util.UsernameEnvVar, util.DefaultUsername),
				util.GetCredVar(util.PasswordEnvVar, util.DefaultPassword),
			},
		},
	}

	tx, err := o.conn.Begin()
	if err != nil {
		return err
	}
	for _, seed := range seeds {
		if _, err := tx.Exec(seed.query, seed.args...); err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
			}
			return err
		}
	}
	return tx.Commit()
}
// GetUser reads the username and password from the users table. Only the
// first row returned by the query is used. The scan error, if any, is returned
// alongside whatever fields were filled in.
func (o *MySQLDB) GetUser() (model.User, error) {
	var u model.User
	err := o.conn.
		QueryRow("SELECT username, password FROM users;").
		Scan(&u.Username, &u.Password)
	return u, err
}
// GetGlobalSettings reads the global settings row from the database.
//
// dns_servers is stored as a single column joined with arrayDelimiter and is
// split back into a slice here. An empty column yields an empty slice rather
// than strings.Split's []string{""}. On error the zero GlobalSetting is
// returned.
func (o *MySQLDB) GetGlobalSettings() (model.GlobalSetting, error) {
	var (
		settings   model.GlobalSetting
		dnsServers string
	)
	// Can't use ScanStruct here as doesn't know how to handle the dns_servers
	// list. Instead we must populate the struct manually.
	err := o.conn.QueryRow("SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, config_file_path, updated_at FROM global_settings;").Scan(
		&settings.EndpointAddress,
		&dnsServers,
		&settings.MTU,
		&settings.PersistentKeepalive,
		&settings.ConfigFilePath,
		&settings.UpdatedAt,
	)
	if err != nil {
		return model.GlobalSetting{}, err
	}

	settings.DNSServers = []string{}
	if dnsServers != "" {
		settings.DNSServers = strings.Split(dnsServers, arrayDelimiter)
	}
	return settings, nil
}
// GetServer reads the server interface and keypair rows and combines them
// into a model.Server.
//
// addresses is stored joined with arrayDelimiter. An empty column yields an
// empty slice. On any error the zero Server is returned, and its Interface and
// KeyPair pointers are nil.
func (o *MySQLDB) GetServer() (model.Server, error) {
	// Interface row. Init inserts this row without post_up/post_down, so those
	// columns may be NULL (unless the schema gives them a default). Reading
	// them through NullString maps NULL to "" instead of failing the Scan.
	var (
		iface            model.ServerInterface
		addresses        string
		postUp, postDown sql.NullString
	)
	if err := o.conn.QueryRow("SELECT addresses, listen_port, updated_at, post_up, post_down FROM interfaces;").Scan(
		&addresses,
		&iface.ListenPort,
		&iface.UpdatedAt,
		&postUp,
		&postDown,
	); err != nil {
		return model.Server{}, err
	}
	iface.Addresses = []string{}
	if addresses != "" {
		iface.Addresses = strings.Split(addresses, arrayDelimiter)
	}
	iface.PostUp = postUp.String
	iface.PostDown = postDown.String

	// Keypair row.
	var keyPair model.ServerKeypair
	if err := o.conn.QueryRow("SELECT private_key, public_key, updated_at FROM keypair;").Scan(
		&keyPair.PrivateKey,
		&keyPair.PublicKey,
		&keyPair.UpdatedAt,
	); err != nil {
		return model.Server{}, err
	}

	return model.Server{
		Interface: &iface,
		KeyPair:   &keyPair,
	}, nil
}
// GetClients func to query Client settings from the database
func (o *MySQLDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
var clients []model.ClientData
rows, err := o.conn.Query("SELECT * FROM clients;")
if err != nil {
return clients, err
}
for rows.Next() {
client := model.Client{}
clientData := model.ClientData{}
var allocatedIPs string
var allowedIPs string
var extraAllowedIPs string
// Get client info
if err := rows.Scan(
&client.ID,
&client.PrivateKey,
&client.PublicKey,
&client.PresharedKey,
&client.Name,
&client.Email,
&allocatedIPs,
&allowedIPs,
&extraAllowedIPs,
&client.UseServerDNS,
&client.Enabled,
&client.CreatedAt,
&client.UpdatedAt,
); err != nil {
return clients, err
}
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
// create the list of clients and their qrcode data
clientData.Client = &client
clients = append(clients, clientData)
}
return clients, nil
}
// GetClientByID func to query Clients by ID from the database
func (o *MySQLDB) GetClientByID(clientID string, hasQRCode bool) (model.ClientData, error) {
client := model.Client{}
clientData := model.ClientData{}
var allocatedIPs string
var allowedIPs string
var extraAllowedIPs string
// read client info
if err := o.conn.QueryRow("SELECT * FROM clients WHERE id = ?;", clientID).Scan(
&client.ID,
&client.PrivateKey,
&client.PublicKey,
&client.PresharedKey,
&client.Name,
&client.Email,
&allocatedIPs,
&allowedIPs,
&extraAllowedIPs,
&client.UseServerDNS,
&client.Enabled,
&client.CreatedAt,
&client.UpdatedAt,
); err != nil {
return clientData, err
}
client.AllocatedIPs = strings.Split(allocatedIPs, arrayDelimiter)
client.AllowedIPs = strings.Split(allowedIPs, arrayDelimiter)
client.ExtraAllowedIPs = strings.Split(extraAllowedIPs, arrayDelimiter)
// generate client qrcode image in base64
if hasQRCode && client.PrivateKey != "" {
server, _ := o.GetServer()
globalSettings, _ := o.GetGlobalSettings()
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
if err == nil {
clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png))
} else {
fmt.Print("Cannot generate QR code: ", err)
}
}
clientData.Client = &client
return clientData, nil
}
// SaveClient inserts client, or updates the existing row when the insert hits
// a duplicate key.
//
// List fields are stored joined with arrayDelimiter. A single parameterized
// INSERT ... ON DUPLICATE KEY UPDATE is atomic on its own. Unlike staging the
// values in session variables (@private_key, ...), it leaves no client data
// behind in the state of the pooled connection.
func (o *MySQLDB) SaveClient(client model.Client) error {
	const query = `
INSERT INTO clients (
	id, private_key, public_key, preshared_key, name, email,
	allocated_ips, allowed_ips, extra_allowed_ips,
	use_server_dns, enabled, created_at, updated_at
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
	id = ?, private_key = ?, public_key = ?, preshared_key = ?, name = ?, email = ?,
	allocated_ips = ?, allowed_ips = ?, extra_allowed_ips = ?,
	use_server_dns = ?, enabled = ?, created_at = ?, updated_at = ?;`

	values := []interface{}{
		client.ID,
		client.PrivateKey,
		client.PublicKey,
		client.PresharedKey,
		client.Name,
		client.Email,
		strings.Join(client.AllocatedIPs, arrayDelimiter),
		strings.Join(client.AllowedIPs, arrayDelimiter),
		strings.Join(client.ExtraAllowedIPs, arrayDelimiter),
		client.UseServerDNS,
		client.Enabled,
		client.CreatedAt,
		client.UpdatedAt,
	}
	// The same 13 values feed the VALUES list and the UPDATE clause.
	args := make([]interface{}, 0, 2*len(values))
	args = append(args, values...)
	args = append(args, values...)

	_, err := o.conn.Exec(query, args...)
	return err
}
// DeleteClient removes the clients row whose id equals clientID. Only the Exec
// error is reported; the number of affected rows is not checked.
func (o *MySQLDB) DeleteClient(clientID string) error {
	_, err := o.conn.Exec("DELETE FROM clients WHERE id=?;", clientID)
	return err
}
// SaveServerInterface overwrites the interfaces row with id = 1 using
// serverInterface. Addresses are stored joined with arrayDelimiter. Only the
// Exec error is reported; the number of affected rows is not checked.
func (o *MySQLDB) SaveServerInterface(serverInterface model.ServerInterface) error {
	// The table only ever holds one record, so a plain UPDATE suffices and no
	// ON DUPLICATE KEY UPDATE is needed.
	const query = `
UPDATE
interfaces
SET
addresses = ?,
listen_port = ?,
updated_at = ?,
post_up = ?,
post_down = ?
WHERE
id = 1;`
	joinedAddresses := strings.Join(serverInterface.Addresses, arrayDelimiter)
	if _, err := o.conn.Exec(
		query,
		joinedAddresses,
		serverInterface.ListenPort,
		serverInterface.UpdatedAt,
		serverInterface.PostUp,
		serverInterface.PostDown,
	); err != nil {
		return err
	}
	return nil
}
// SaveServerKeyPair overwrites the keypair row with id = 1 using
// serverKeyPair. Only the Exec error is reported; the number of affected rows
// is not checked.
func (o *MySQLDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
	const query = `
UPDATE
keypair
SET
private_key = ?,
public_key = ?,
updated_at = ?
WHERE
id = 1;`
	args := []interface{}{
		serverKeyPair.PrivateKey,
		serverKeyPair.PublicKey,
		serverKeyPair.UpdatedAt,
	}
	if _, err := o.conn.Exec(query, args...); err != nil {
		return err
	}
	return nil
}
// SaveGlobalSettings overwrites the global_settings row with id = 1 using
// globalSettings. DNS servers are stored joined with arrayDelimiter. Only the
// Exec error is reported; the number of affected rows is not checked.
func (o *MySQLDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
	const query = `
UPDATE
global_settings
SET
endpoint_address = ?,
dns_servers = ?,
mtu = ?,
persistent_keepalive = ?,
config_file_path = ?,
updated_at = ?
WHERE
id = 1;`
	args := []interface{}{
		globalSettings.EndpointAddress,
		strings.Join(globalSettings.DNSServers, arrayDelimiter),
		globalSettings.MTU,
		globalSettings.PersistentKeepalive,
		globalSettings.ConfigFilePath,
		globalSettings.UpdatedAt,
	}
	_, err := o.conn.Exec(query, args...)
	return err
}