517 lines
16 KiB
Go
517 lines
16 KiB
Go
package sqlitedb
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
|
|
"github.com/skip2/go-qrcode"
|
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
|
|
"github.com/DigitalTolk/wireguard-ui/model"
|
|
"github.com/DigitalTolk/wireguard-ui/util"
|
|
)
|
|
|
|
// userColumns is the SELECT list shared by every users query. oidc_sub is
// wrapped in COALESCE so a NULL subject scans into a plain string as "".
const userColumns = "username, email, display_name, COALESCE(oidc_sub, ''), admin, created_at, updated_at"

// qrCodeDataURIPrefix is prepended to base64-encoded PNG bytes to form an
// inline data URI for client QR codes.
const qrCodeDataURIPrefix = "data:image/png;base64,"
|
|
|
|
//go:embed schema.sql
|
|
var schemaFS embed.FS
|
|
|
|
// SqliteDB is the SQLite-backed implementation of store.IStore.
type SqliteDB struct {
	db     *sql.DB // connection pool opened by New
	dbPath string  // database file path passed to New
}
|
|
|
|
// New creates a new SqliteDB instance
|
|
func New(dbPath string) (*SqliteDB, error) {
|
|
// ensure parent directory exists
|
|
dir := filepath.Dir(dbPath)
|
|
if err := os.MkdirAll(dir, 0750); err != nil {
|
|
return nil, fmt.Errorf("cannot create database directory: %w", err)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite", dbPath+"?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=ON")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot open database: %w", err)
|
|
}
|
|
|
|
// apply schema
|
|
schema, err := schemaFS.ReadFile("schema.sql")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot read schema: %w", err)
|
|
}
|
|
if _, err := db.Exec(string(schema)); err != nil {
|
|
return nil, fmt.Errorf("cannot apply schema: %w", err)
|
|
}
|
|
|
|
return &SqliteDB{db: db, dbPath: dbPath}, nil
|
|
}
|
|
|
|
// Init initializes the database with default values if they don't exist
|
|
func (o *SqliteDB) Init() error {
|
|
// server interface
|
|
var ifaceCount int
|
|
o.db.QueryRow("SELECT COUNT(*) FROM server_interface").Scan(&ifaceCount)
|
|
if ifaceCount == 0 {
|
|
addresses := util.LookupEnvOrStrings(util.ServerAddressesEnvVar, []string{util.DefaultServerAddress})
|
|
listenPort := util.LookupEnvOrInt(util.ServerListenPortEnvVar, util.DefaultServerPort)
|
|
postUp := util.LookupEnvOrString(util.ServerPostUpScriptEnvVar, "")
|
|
postDown := util.LookupEnvOrString(util.ServerPostDownScriptEnvVar, "")
|
|
addrJSON, _ := json.Marshal(addresses)
|
|
_, err := o.db.Exec(
|
|
`INSERT INTO server_interface (id, addresses, listen_port, post_up, post_down, updated_at) VALUES (1, ?, ?, ?, ?, ?)`,
|
|
string(addrJSON), listenPort, postUp, postDown, time.Now().UTC(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot init server interface: %w", err)
|
|
}
|
|
}
|
|
|
|
// server keypair
|
|
var kpCount int
|
|
o.db.QueryRow("SELECT COUNT(*) FROM server_keypair").Scan(&kpCount)
|
|
if kpCount == 0 {
|
|
key, err := wgtypes.GeneratePrivateKey()
|
|
if err != nil {
|
|
return fmt.Errorf("cannot generate server keypair: %w", err)
|
|
}
|
|
_, err = o.db.Exec(
|
|
`INSERT INTO server_keypair (id, private_key, public_key, updated_at) VALUES (1, ?, ?, ?)`,
|
|
key.String(), key.PublicKey().String(), time.Now().UTC(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot init server keypair: %w", err)
|
|
}
|
|
}
|
|
|
|
// global settings
|
|
var gsCount int
|
|
o.db.QueryRow("SELECT COUNT(*) FROM global_settings").Scan(&gsCount)
|
|
if gsCount == 0 {
|
|
endpointAddress := util.LookupEnvOrString(util.EndpointAddressEnvVar, "")
|
|
if endpointAddress == "" {
|
|
publicInterface, err := util.GetPublicIP()
|
|
if err != nil {
|
|
return fmt.Errorf("cannot detect public IP: %w", err)
|
|
}
|
|
endpointAddress = publicInterface.IPAddress
|
|
}
|
|
dnsServers := util.LookupEnvOrStrings(util.DNSEnvVar, []string{util.DefaultDNS})
|
|
dnsJSON, _ := json.Marshal(dnsServers)
|
|
_, err := o.db.Exec(
|
|
`INSERT INTO global_settings (id, endpoint_address, dns_servers, mtu, persistent_keepalive, firewall_mark, "table", config_file_path, updated_at)
|
|
VALUES (1, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
endpointAddress,
|
|
string(dnsJSON),
|
|
util.LookupEnvOrInt(util.MTUEnvVar, util.DefaultMTU),
|
|
util.LookupEnvOrInt(util.PersistentKeepaliveEnvVar, util.DefaultPersistentKeepalive),
|
|
util.LookupEnvOrString(util.FirewallMarkEnvVar, util.DefaultFirewallMark),
|
|
util.LookupEnvOrString(util.TableEnvVar, util.DefaultTable),
|
|
util.LookupEnvOrString(util.ConfigFilePathEnvVar, util.DefaultConfigFilePath),
|
|
time.Now().UTC(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("cannot init global settings: %w", err)
|
|
}
|
|
}
|
|
|
|
// hashes
|
|
var hashCount int
|
|
o.db.QueryRow("SELECT COUNT(*) FROM hashes").Scan(&hashCount)
|
|
if hashCount == 0 {
|
|
o.db.Exec(`INSERT INTO hashes (id, client, server) VALUES (1, 'none', 'none')`)
|
|
}
|
|
|
|
// init caches (first OIDC login auto-provisions admin user)
|
|
users, err := o.GetUsers()
|
|
if err == nil {
|
|
util.DBUsersToCRC32Mutex.Lock()
|
|
for _, user := range users {
|
|
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
|
|
}
|
|
util.DBUsersToCRC32Mutex.Unlock()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetUsers returns all users
|
|
func (o *SqliteDB) GetUsers() ([]model.User, error) {
|
|
rows, err := o.db.Query("SELECT " + userColumns + " FROM users")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var users []model.User
|
|
for rows.Next() {
|
|
var u model.User
|
|
if err := rows.Scan(&u.Username, &u.Email, &u.DisplayName, &u.OIDCSub, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, u)
|
|
}
|
|
return users, rows.Err()
|
|
}
|
|
|
|
// GetUserByName returns a single user by username
|
|
func (o *SqliteDB) GetUserByName(username string) (model.User, error) {
|
|
var u model.User
|
|
err := o.db.QueryRow(
|
|
"SELECT "+userColumns+" FROM users WHERE username = ?",
|
|
username,
|
|
).Scan(&u.Username, &u.Email, &u.DisplayName, &u.OIDCSub, &u.Admin, &u.CreatedAt, &u.UpdatedAt)
|
|
if err != nil {
|
|
return u, err
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// SaveUser creates or updates a user
|
|
func (o *SqliteDB) SaveUser(user model.User) error {
|
|
now := time.Now().UTC()
|
|
if user.UpdatedAt.IsZero() {
|
|
user.UpdatedAt = now
|
|
}
|
|
if user.CreatedAt.IsZero() {
|
|
user.CreatedAt = now
|
|
}
|
|
|
|
_, err := o.db.Exec(
|
|
`INSERT INTO users (username, email, display_name, oidc_sub, admin, created_at, updated_at)
|
|
VALUES (?, ?, ?, NULLIF(?, ''), ?, ?, ?)
|
|
ON CONFLICT(username) DO UPDATE SET
|
|
email = excluded.email,
|
|
display_name = excluded.display_name,
|
|
oidc_sub = excluded.oidc_sub,
|
|
admin = excluded.admin,
|
|
updated_at = excluded.updated_at`,
|
|
user.Username, user.Email, user.DisplayName, user.OIDCSub, user.Admin, user.CreatedAt, user.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
util.DBUsersToCRC32Mutex.Lock()
|
|
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
|
|
util.DBUsersToCRC32Mutex.Unlock()
|
|
return nil
|
|
}
|
|
|
|
// DeleteUser removes a user by username
|
|
func (o *SqliteDB) DeleteUser(username string) error {
|
|
util.DBUsersToCRC32Mutex.Lock()
|
|
delete(util.DBUsersToCRC32, username)
|
|
util.DBUsersToCRC32Mutex.Unlock()
|
|
_, err := o.db.Exec("DELETE FROM users WHERE username = ?", username)
|
|
return err
|
|
}
|
|
|
|
// GetGlobalSettings returns global WireGuard settings
|
|
func (o *SqliteDB) GetGlobalSettings() (model.GlobalSetting, error) {
|
|
var gs model.GlobalSetting
|
|
var dnsJSON string
|
|
err := o.db.QueryRow(
|
|
`SELECT endpoint_address, dns_servers, mtu, persistent_keepalive, firewall_mark, "table", config_file_path, updated_at
|
|
FROM global_settings WHERE id = 1`,
|
|
).Scan(&gs.EndpointAddress, &dnsJSON, &gs.MTU, &gs.PersistentKeepalive, &gs.FirewallMark, &gs.Table, &gs.ConfigFilePath, &gs.UpdatedAt)
|
|
if err != nil {
|
|
return gs, err
|
|
}
|
|
json.Unmarshal([]byte(dnsJSON), &gs.DNSServers)
|
|
return gs, nil
|
|
}
|
|
|
|
// GetServer returns the server config (interface + keypair)
|
|
func (o *SqliteDB) GetServer() (model.Server, error) {
|
|
server := model.Server{}
|
|
|
|
// interface
|
|
iface := model.ServerInterface{}
|
|
var addrJSON string
|
|
err := o.db.QueryRow(
|
|
"SELECT addresses, listen_port, post_up, pre_down, post_down, updated_at FROM server_interface WHERE id = 1",
|
|
).Scan(&addrJSON, &iface.ListenPort, &iface.PostUp, &iface.PreDown, &iface.PostDown, &iface.UpdatedAt)
|
|
if err != nil {
|
|
return server, fmt.Errorf("cannot read server interface: %w", err)
|
|
}
|
|
json.Unmarshal([]byte(addrJSON), &iface.Addresses)
|
|
server.Interface = &iface
|
|
|
|
// keypair
|
|
kp := model.ServerKeypair{}
|
|
err = o.db.QueryRow(
|
|
"SELECT private_key, public_key, updated_at FROM server_keypair WHERE id = 1",
|
|
).Scan(&kp.PrivateKey, &kp.PublicKey, &kp.UpdatedAt)
|
|
if err != nil {
|
|
return server, fmt.Errorf("cannot read server keypair: %w", err)
|
|
}
|
|
server.KeyPair = &kp
|
|
|
|
return server, nil
|
|
}
|
|
|
|
// GetClients returns all clients, optionally with QR codes
|
|
func (o *SqliteDB) GetClients(hasQRCode bool) ([]model.ClientData, error) {
|
|
rows, err := o.db.Query(
|
|
`SELECT id, private_key, public_key, preshared_key, name, email, telegram_userid,
|
|
subnet_ranges, allocated_ips, allowed_ips, extra_allowed_ips,
|
|
endpoint, additional_notes, use_server_dns, enabled, created_at, updated_at
|
|
FROM clients`,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// fetch server/settings once outside the loop for QR generation
|
|
var server model.Server
|
|
var globalSettings model.GlobalSetting
|
|
if hasQRCode {
|
|
server, _ = o.GetServer()
|
|
globalSettings, _ = o.GetGlobalSettings()
|
|
}
|
|
|
|
var clients []model.ClientData
|
|
for rows.Next() {
|
|
client, err := scanClientFrom(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clientData := model.ClientData{Client: &client}
|
|
|
|
if hasQRCode && client.PrivateKey != "" {
|
|
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
|
|
if err == nil {
|
|
clientData.QRCode = qrCodeDataURIPrefix + base64.StdEncoding.EncodeToString(png)
|
|
}
|
|
}
|
|
|
|
clients = append(clients, clientData)
|
|
}
|
|
return clients, rows.Err()
|
|
}
|
|
|
|
// GetClientByID returns a single client by ID
|
|
func (o *SqliteDB) GetClientByID(clientID string, qrCodeSettings model.QRCodeSettings) (model.ClientData, error) {
|
|
clientData := model.ClientData{}
|
|
|
|
row := o.db.QueryRow(
|
|
`SELECT id, private_key, public_key, preshared_key, name, email, telegram_userid,
|
|
subnet_ranges, allocated_ips, allowed_ips, extra_allowed_ips,
|
|
endpoint, additional_notes, use_server_dns, enabled, created_at, updated_at
|
|
FROM clients WHERE id = ?`, clientID,
|
|
)
|
|
|
|
client, err := scanClientFrom(row)
|
|
if err != nil {
|
|
return clientData, err
|
|
}
|
|
|
|
if qrCodeSettings.Enabled && client.PrivateKey != "" {
|
|
server, _ := o.GetServer()
|
|
globalSettings, _ := o.GetGlobalSettings()
|
|
if !qrCodeSettings.IncludeDNS {
|
|
globalSettings.DNSServers = []string{}
|
|
}
|
|
if !qrCodeSettings.IncludeMTU {
|
|
globalSettings.MTU = 0
|
|
}
|
|
png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256)
|
|
if err == nil {
|
|
clientData.QRCode = qrCodeDataURIPrefix + base64.StdEncoding.EncodeToString(png)
|
|
}
|
|
}
|
|
|
|
clientData.Client = &client
|
|
return clientData, nil
|
|
}
|
|
|
|
// SaveClient creates or updates a client
|
|
func (o *SqliteDB) SaveClient(client model.Client) error {
|
|
subnetJSON, _ := json.Marshal(client.SubnetRanges)
|
|
allocJSON, _ := json.Marshal(client.AllocatedIPs)
|
|
allowJSON, _ := json.Marshal(client.AllowedIPs)
|
|
extraJSON, _ := json.Marshal(client.ExtraAllowedIPs)
|
|
|
|
_, err := o.db.Exec(
|
|
`INSERT INTO clients (id, private_key, public_key, preshared_key, name, email, telegram_userid,
|
|
subnet_ranges, allocated_ips, allowed_ips, extra_allowed_ips,
|
|
endpoint, additional_notes, use_server_dns, enabled, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
private_key = excluded.private_key,
|
|
public_key = excluded.public_key,
|
|
preshared_key = excluded.preshared_key,
|
|
name = excluded.name,
|
|
email = excluded.email,
|
|
telegram_userid = excluded.telegram_userid,
|
|
subnet_ranges = excluded.subnet_ranges,
|
|
allocated_ips = excluded.allocated_ips,
|
|
allowed_ips = excluded.allowed_ips,
|
|
extra_allowed_ips = excluded.extra_allowed_ips,
|
|
endpoint = excluded.endpoint,
|
|
additional_notes = excluded.additional_notes,
|
|
use_server_dns = excluded.use_server_dns,
|
|
enabled = excluded.enabled,
|
|
updated_at = excluded.updated_at`,
|
|
client.ID, client.PrivateKey, client.PublicKey, client.PresharedKey,
|
|
client.Name, client.Email, client.TgUserid,
|
|
string(subnetJSON), string(allocJSON), string(allowJSON), string(extraJSON),
|
|
client.Endpoint, client.AdditionalNotes, client.UseServerDNS, client.Enabled,
|
|
client.CreatedAt, client.UpdatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// DeleteClient removes a client by ID
|
|
func (o *SqliteDB) DeleteClient(clientID string) error {
|
|
_, err := o.db.Exec("DELETE FROM clients WHERE id = ?", clientID)
|
|
return err
|
|
}
|
|
|
|
// SaveServerInterface updates the server interface config
|
|
func (o *SqliteDB) SaveServerInterface(serverInterface model.ServerInterface) error {
|
|
addrJSON, _ := json.Marshal(serverInterface.Addresses)
|
|
_, err := o.db.Exec(
|
|
`UPDATE server_interface SET addresses = ?, listen_port = ?, post_up = ?, pre_down = ?, post_down = ?, updated_at = ? WHERE id = 1`,
|
|
string(addrJSON), serverInterface.ListenPort, serverInterface.PostUp, serverInterface.PreDown, serverInterface.PostDown, serverInterface.UpdatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// SaveServerKeyPair updates the server keypair
|
|
func (o *SqliteDB) SaveServerKeyPair(serverKeyPair model.ServerKeypair) error {
|
|
_, err := o.db.Exec(
|
|
`UPDATE server_keypair SET private_key = ?, public_key = ?, updated_at = ? WHERE id = 1`,
|
|
serverKeyPair.PrivateKey, serverKeyPair.PublicKey, serverKeyPair.UpdatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// SaveGlobalSettings updates global settings
|
|
func (o *SqliteDB) SaveGlobalSettings(globalSettings model.GlobalSetting) error {
|
|
dnsJSON, _ := json.Marshal(globalSettings.DNSServers)
|
|
_, err := o.db.Exec(
|
|
`UPDATE global_settings SET endpoint_address = ?, dns_servers = ?, mtu = ?, persistent_keepalive = ?,
|
|
firewall_mark = ?, "table" = ?, config_file_path = ?, updated_at = ? WHERE id = 1`,
|
|
globalSettings.EndpointAddress, string(dnsJSON), globalSettings.MTU, globalSettings.PersistentKeepalive,
|
|
globalSettings.FirewallMark, globalSettings.Table, globalSettings.ConfigFilePath, globalSettings.UpdatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// GetAllocatedIPs returns all IP addresses allocated to clients and server
|
|
func (o *SqliteDB) GetAllocatedIPs(excludeClientID string) ([]string, error) {
|
|
allocatedIPs := make([]string, 0)
|
|
|
|
// server addresses
|
|
var addrJSON string
|
|
err := o.db.QueryRow("SELECT addresses FROM server_interface WHERE id = 1").Scan(&addrJSON)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var serverAddrs []string
|
|
json.Unmarshal([]byte(addrJSON), &serverAddrs)
|
|
for _, cidr := range serverAddrs {
|
|
ip, _, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allocatedIPs = append(allocatedIPs, ip.String())
|
|
}
|
|
|
|
// client addresses
|
|
rows, err := o.db.Query("SELECT allocated_ips FROM clients WHERE id != ?", excludeClientID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var ipsJSON string
|
|
if err := rows.Scan(&ipsJSON); err != nil {
|
|
return nil, err
|
|
}
|
|
var ips []string
|
|
json.Unmarshal([]byte(ipsJSON), &ips)
|
|
for _, cidr := range ips {
|
|
ip, _, err := net.ParseCIDR(cidr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allocatedIPs = append(allocatedIPs, ip.String())
|
|
}
|
|
}
|
|
|
|
return allocatedIPs, rows.Err()
|
|
}
|
|
|
|
// GetPath returns the database file path
|
|
func (o *SqliteDB) GetPath() string {
|
|
return filepath.Dir(o.dbPath)
|
|
}
|
|
|
|
// GetHashes returns stored hashes
|
|
func (o *SqliteDB) GetHashes() (model.ClientServerHashes, error) {
|
|
var h model.ClientServerHashes
|
|
err := o.db.QueryRow("SELECT client, server FROM hashes WHERE id = 1").Scan(&h.Client, &h.Server)
|
|
return h, err
|
|
}
|
|
|
|
// SaveHashes updates stored hashes
|
|
func (o *SqliteDB) SaveHashes(hashes model.ClientServerHashes) error {
|
|
_, err := o.db.Exec("UPDATE hashes SET client = ?, server = ? WHERE id = 1", hashes.Client, hashes.Server)
|
|
return err
|
|
}
|
|
|
|
// DB returns the underlying sql.DB for direct access (e.g., audit logs)
|
|
func (o *SqliteDB) DB() *sql.DB {
|
|
return o.db
|
|
}
|
|
|
|
// GetUserByOIDCSub returns a user by their OIDC subject identifier
|
|
func (o *SqliteDB) GetUserByOIDCSub(sub string) (model.User, error) {
|
|
var u model.User
|
|
err := o.db.QueryRow(
|
|
"SELECT "+userColumns+" FROM users WHERE oidc_sub = ?", sub,
|
|
).Scan(&u.Username, &u.Email, &u.DisplayName, &u.OIDCSub, &u.Admin, &u.CreatedAt, &u.UpdatedAt)
|
|
return u, err
|
|
}
|
|
|
|
// scanner abstracts the Scan method shared by *sql.Row and *sql.Rows, letting
// scanClientFrom decode a client from either a single-row or multi-row query.
type scanner interface {
	Scan(dest ...any) error
}
|
|
|
|
func scanClientFrom(s scanner) (model.Client, error) {
|
|
var c model.Client
|
|
var subnetJSON, allocJSON, allowJSON, extraJSON string
|
|
err := s.Scan(
|
|
&c.ID, &c.PrivateKey, &c.PublicKey, &c.PresharedKey,
|
|
&c.Name, &c.Email, &c.TgUserid,
|
|
&subnetJSON, &allocJSON, &allowJSON, &extraJSON,
|
|
&c.Endpoint, &c.AdditionalNotes, &c.UseServerDNS, &c.Enabled,
|
|
&c.CreatedAt, &c.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return c, err
|
|
}
|
|
json.Unmarshal([]byte(subnetJSON), &c.SubnetRanges)
|
|
json.Unmarshal([]byte(allocJSON), &c.AllocatedIPs)
|
|
json.Unmarshal([]byte(allowJSON), &c.AllowedIPs)
|
|
json.Unmarshal([]byte(extraJSON), &c.ExtraAllowedIPs)
|
|
return c, nil
|
|
}
|