feat: add GitHub auth configuration and session foundation

This commit is contained in:
devcxl 2026-04-04 07:34:17 +08:00
parent 9d7948417b
commit bdbb9250b8
12 changed files with 1157 additions and 67 deletions

View File

@ -13,8 +13,6 @@ import (
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/labstack/gommon/log"
"github.com/rs/xid"
@ -90,43 +88,10 @@ func Login(db store.IStore) echo.HandlerFunc {
}
if userCorrect && passwordCorrect {
ageMax := 0
if rememberMe {
ageMax = 86400 * 7
if err := establishAuthenticatedSession(c, dbuser, rememberMe); err != nil {
log.Errorf("Failed to establish authenticated session: %v", err)
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Failed to establish session"})
}
cookiePath := util.GetCookiePath()
sess, _ := session.Get("session", c)
sess.Options = &sessions.Options{
Path: cookiePath,
MaxAge: ageMax,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
// set session_token
tokenUID := xid.New().String()
now := time.Now().UTC().Unix()
sess.Values["username"] = dbuser.Username
sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser)
sess.Values["admin"] = dbuser.Admin
sess.Values["session_token"] = tokenUID
sess.Values["max_age"] = ageMax
sess.Values["created_at"] = now
sess.Values["updated_at"] = now
sess.Save(c.Request(), c.Response())
// set session_token in cookie
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = tokenUID
cookie.MaxAge = ageMax
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"})
}
@ -588,6 +553,7 @@ func UpdateClient(db store.IStore) echo.HandlerFunc {
}
client := *clientData.Client
allocatedIPs, err := util.GetAllocatedIPs("")
check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, _client.AllocatedIPs)
if !check {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)})

View File

@ -8,6 +8,9 @@ import (
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/rs/xid"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/ngoduykhanh/wireguard-ui/util"
)
@ -247,3 +250,47 @@ func clearSession(c echo.Context) {
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
}
func establishAuthenticatedSession(c echo.Context, user model.User, rememberMe bool) error {
sess, err := session.Get("session", c)
if err != nil {
return err
}
ageMax := 0
if rememberMe {
ageMax = 86400 * 7
}
cookiePath := util.GetCookiePath()
sess.Options = &sessions.Options{
Path: cookiePath,
MaxAge: ageMax,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}
tokenUID := xid.New().String()
now := time.Now().UTC().Unix()
sess.Values["username"] = user.Username
sess.Values["user_hash"] = util.GetDBUserCRC32(user)
sess.Values["admin"] = user.Admin
sess.Values["session_token"] = tokenUID
sess.Values["max_age"] = ageMax
sess.Values["created_at"] = now
sess.Values["updated_at"] = now
sess.Save(c.Request(), c.Response())
cookie := new(http.Cookie)
cookie.Name = "session_token"
cookie.Path = cookiePath
cookie.Value = tokenUID
cookie.MaxAge = ageMax
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteLaxMode
c.SetCookie(cookie)
return nil
}

View File

@ -7,4 +7,7 @@ type User struct {
// PasswordHash takes precedence over Password.
PasswordHash string `json:"password_hash"`
Admin bool `json:"admin"`
AuthSource string `json:"auth_source"`
AuthSubject string `json:"auth_subject"`
DisplayName string `json:"display_name"`
}

View File

@ -0,0 +1,374 @@
package jsondb
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/sdomino/scribble"
)
func TestGetUserByAuthIdentity_FindsGitHubUser(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
githubUser := model.User{
Username: "github-user",
Password: "",
PasswordHash: "",
Admin: false,
AuthSource: "github",
AuthSubject: "123456",
DisplayName: "GitHub User",
}
if err := db.SaveUser(githubUser); err != nil {
t.Fatalf("failed to save user: %v", err)
}
found, err := db.GetUserByAuthIdentity("github", "123456")
if err != nil {
t.Fatalf("GetUserByAuthIdentity failed: %v", err)
}
if found.Username != "github-user" {
t.Errorf("expected username 'github-user', got %q", found.Username)
}
if found.AuthSource != "github" {
t.Errorf("expected auth_source 'github', got %q", found.AuthSource)
}
if found.AuthSubject != "123456" {
t.Errorf("expected auth_subject '123456', got %q", found.AuthSubject)
}
}
func TestGetUserByAuthIdentity_NotFound(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
_, err = db.GetUserByAuthIdentity("github", "nonexistent")
if err == nil {
t.Error("expected error for nonexistent auth identity")
}
}
func TestGetUserByAuthIdentity_LocalUserNotFound(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
localUser := model.User{
Username: "local-user",
Password: "secret",
PasswordHash: "hash",
Admin: false,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Local User",
}
if err := db.SaveUser(localUser); err != nil {
t.Fatalf("failed to save user: %v", err)
}
_, err = db.GetUserByAuthIdentity("local", "")
if err == nil {
t.Error("expected error for local user queried by empty auth identity")
}
}
func TestReplaceUser_RenameKeepsOnlyNewRecord(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
oldUser := model.User{
Username: "oldname",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Old Name",
}
if err := db.SaveUser(oldUser); err != nil {
t.Fatalf("failed to save user: %v", err)
}
newUser := model.User{
Username: "newname",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "New Name",
}
if err := db.ReplaceUser("oldname", newUser); err != nil {
t.Fatalf("ReplaceUser failed: %v", err)
}
_, err = db.GetUserByName("oldname")
if err == nil {
t.Error("old username should not exist after rename")
}
_, err = db.GetUserByName("newname")
if err != nil {
t.Errorf("new username should exist after rename: %v", err)
}
users, err := db.GetUsers()
if err != nil {
t.Fatalf("GetUsers failed: %v", err)
}
if len(users) != 1 {
t.Errorf("expected exactly 1 user after rename, got %d", len(users))
}
}
func TestReplaceUser_SameNameUpdatesRecord(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
originalUser := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Original",
}
if err := db.SaveUser(originalUser); err != nil {
t.Fatalf("failed to save user: %v", err)
}
updatedUser := model.User{
Username: "testuser",
Password: "newpassword",
PasswordHash: "newhash",
Admin: false,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Updated",
}
if err := db.ReplaceUser("testuser", updatedUser); err != nil {
t.Fatalf("ReplaceUser failed: %v", err)
}
found, err := db.GetUserByName("testuser")
if err != nil {
t.Fatalf("GetUserByName failed: %v", err)
}
if found.Admin != false {
t.Errorf("expected admin=false, got %v", found.Admin)
}
if found.DisplayName != "Updated" {
t.Errorf("expected display_name='Updated', got %q", found.DisplayName)
}
users, err := db.GetUsers()
if err != nil {
t.Fatalf("GetUsers failed: %v", err)
}
if len(users) != 1 {
t.Errorf("expected exactly 1 user after update, got %d", len(users))
}
}
func TestReplaceUser_MigratesLegacyLocalUser(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
legacyUser := model.User{
Username: "legacyuser",
Password: "oldpassword",
PasswordHash: "oldhash",
Admin: false,
AuthSource: "",
AuthSubject: "",
DisplayName: "",
}
if err := db.SaveUser(legacyUser); err != nil {
t.Fatalf("failed to save legacy user: %v", err)
}
migratedUser := model.User{
Username: "legacyuser",
Password: "",
PasswordHash: "",
Admin: false,
AuthSource: "github",
AuthSubject: "987654",
DisplayName: "Migrated User",
}
if err := db.ReplaceUser("legacyuser", migratedUser); err != nil {
t.Fatalf("ReplaceUser failed: %v", err)
}
found, err := db.GetUserByAuthIdentity("github", "987654")
if err != nil {
t.Fatalf("GetUserByAuthIdentity failed: %v", err)
}
if found.Username != "legacyuser" {
t.Errorf("expected username 'legacyuser', got %q", found.Username)
}
if found.AuthSource != "github" {
t.Errorf("expected auth_source 'github', got %q", found.AuthSource)
}
if found.AuthSubject != "987654" {
t.Errorf("expected auth_subject '987654', got %q", found.AuthSubject)
}
if found.Password != "" || found.PasswordHash != "" {
t.Errorf("expected empty password fields for migrated user, got password=%q, password_hash=%q", found.Password, found.PasswordHash)
}
}
func TestReplaceUser_WriteNewFails_KeepsOldRecord(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
existingUser := model.User{
Username: "existing",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Existing User",
}
if err := db.SaveUser(existingUser); err != nil {
t.Fatalf("failed to save user: %v", err)
}
userDir := filepath.Join(tmpDir, "users")
os.Chmod(userDir, 0000)
newUser := model.User{
Username: "newuser",
Password: "password",
PasswordHash: "hash",
Admin: false,
AuthSource: "github",
AuthSubject: "123456",
DisplayName: "New User",
}
err = db.ReplaceUser("existing", newUser)
os.Chmod(userDir, 0700)
if err == nil {
t.Error("expected error when writing new record fails")
}
_, err = db.GetUserByName("existing")
if err != nil {
t.Errorf("old record should still exist after failed replace: %v", err)
}
_, err = db.GetUserByName("newuser")
if err == nil {
t.Error("new record should not exist after failed replace")
}
}
func TestReplaceUser_DeleteOldFails_AttemptsRollbackBeforeManualCleanupError(t *testing.T) {
tmpDir := t.TempDir()
db, err := New(tmpDir)
if err != nil {
t.Fatalf("failed to create db: %v", err)
}
oldUser := model.User{
Username: "olduser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Old User",
}
if err := db.SaveUser(oldUser); err != nil {
t.Fatalf("failed to save user: %v", err)
}
newUser := model.User{
Username: "newuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "New User",
}
origDeleteUserFile := deleteUserFile
deleteUserFile = func(conn *scribble.Driver, username string) error {
if username == "olduser" {
return fmt.Errorf("simulated delete failure for olduser")
}
if username == "newuser" {
return fmt.Errorf("simulated rollback failure for newuser")
}
return origDeleteUserFile(conn, username)
}
defer func() { deleteUserFile = origDeleteUserFile }()
err = db.ReplaceUser("olduser", newUser)
if err == nil {
t.Fatal("expected error when delete old fails and rollback also fails")
}
if !strings.Contains(err.Error(), "manual cleanup required") {
t.Errorf("error should contain 'manual cleanup required', got: %v", err)
}
if !strings.Contains(err.Error(), "rollback failed") {
t.Errorf("error should contain 'rollback failed', got: %v", err)
}
_, err = db.GetUserByName("olduser")
if err != nil {
t.Errorf("old record should still exist: %v", err)
}
_, err = db.GetUserByName("newuser")
if err != nil {
t.Errorf("new record should still exist after failed rollback: %v", err)
}
}

View File

@ -16,6 +16,10 @@ import (
"github.com/ngoduykhanh/wireguard-ui/util"
)
var deleteUserFile = func(conn *scribble.Driver, username string) error {
return conn.Delete("users", username)
}
type JsonDB struct {
conn *scribble.Driver
dbPath string
@ -216,7 +220,78 @@ func (o *JsonDB) SaveUser(user model.User) error {
// DeleteUser func to remove user from the database
func (o *JsonDB) DeleteUser(username string) error {
delete(util.DBUsersToCRC32, username)
return o.conn.Delete("users", username)
return deleteUserFile(o.conn, username)
}
// GetUserByAuthIdentity func to get a user by auth source and subject
func (o *JsonDB) GetUserByAuthIdentity(authSource, authSubject string) (model.User, error) {
if authSubject == "" {
return model.User{}, fmt.Errorf("auth subject cannot be empty for auth source %q", authSource)
}
users, err := o.GetUsers()
if err != nil {
return model.User{}, err
}
for _, user := range users {
if user.AuthSource == authSource && user.AuthSubject == authSubject {
return user, nil
}
}
return model.User{}, fmt.Errorf("user not found for auth source %q and subject %q", authSource, authSubject)
}
// ReplaceUser safely replaces or renames a user record.
// If oldUsername == user.Username, it performs an update.
// If they differ, it writes the new record first, then deletes the old one.
// If the old record deletion fails after the new record was written, the new
// record remains and an error indicating manual cleanup is returned.
func (o *JsonDB) ReplaceUser(oldUsername string, user model.User) error {
userPath := path.Join(path.Join(o.dbPath, "users"), user.Username+".json")
oldUserPath := path.Join(path.Join(o.dbPath, "users"), oldUsername+".json")
isRename := oldUsername != user.Username
if isRename {
_, err := os.Stat(oldUserPath)
if err != nil {
return fmt.Errorf("old user record does not exist: %w", err)
}
_, err = os.Stat(userPath)
if err == nil {
return fmt.Errorf("cannot rename to %q: a record with that name already exists", user.Username)
}
}
if err := o.conn.Write("users", user.Username, user); err != nil {
return fmt.Errorf("failed to write new user record: %w", err)
}
if err := util.ManagePerms(userPath); err != nil {
os.Remove(userPath)
return fmt.Errorf("failed to set permissions on new user record: %w", err)
}
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
if isRename {
delete(util.DBUsersToCRC32, oldUsername)
if err := deleteUserFile(o.conn, oldUsername); err != nil {
rollbackErr := deleteUserFile(o.conn, user.Username)
if rollbackErr != nil {
return fmt.Errorf("failed to delete old user record %q (new record %q written): %v; rollback failed: %v; manual cleanup required: remove %s and ensure %s exists", oldUsername, user.Username, err, rollbackErr, userPath, oldUserPath)
}
delete(util.DBUsersToCRC32, user.Username)
os.RemoveAll(userPath)
return fmt.Errorf("failed to delete old user record %q (new record %q was written and rolled back): %v", oldUsername, user.Username, err)
}
}
return nil
}
// GetGlobalSettings func to query global settings from the database

View File

@ -8,7 +8,9 @@ type IStore interface {
Init() error
GetUsers() ([]model.User, error)
GetUserByName(username string) (model.User, error)
GetUserByAuthIdentity(authSource, authSubject string) (model.User, error)
SaveUser(user model.User) error
ReplaceUser(oldUsername string, user model.User) error
DeleteUser(username string) error
GetGlobalSettings() (model.GlobalSetting, error)
GetServer() (model.Server, error)

124
util/auth_config.go Normal file
View File

@ -0,0 +1,124 @@
package util
import (
"errors"
"fmt"
"os"
"strings"
)
type AuthMethod string
const (
AuthMethodLocal AuthMethod = "local"
AuthMethodGitHub AuthMethod = "github"
)
const (
AuthMethodEnvVar = "WGUI_AUTH_METHOD"
GitHubClientIDEnvVar = "WGUI_GITHUB_CLIENT_ID"
GitHubClientSecretEnvVar = "WGUI_GITHUB_CLIENT_SECRET"
GitHubClientSecretFileEnvVar = "WGUI_GITHUB_CLIENT_SECRET_FILE"
GitHubRedirectURLEnvVar = "WGUI_GITHUB_REDIRECT_URL"
GitHubAllowedUsersEnvVar = "WGUI_GITHUB_ALLOWED_USERS"
GitHubAllowedOrgsEnvVar = "WGUI_GITHUB_ALLOWED_ORGS"
GitHubAdminUsersEnvVar = "WGUI_GITHUB_ADMIN_USERS"
)
type GitHubAuthConfig struct {
ClientID string
ClientSecret string
ClientSecretFile string
RedirectURL string
AllowedUsers []string
AllowedOrgs []string
AdminUsers []string
}
func ParseAuthMethod() AuthMethod {
method := LookupEnvOrString(AuthMethodEnvVar, "local")
switch strings.ToLower(method) {
case "github":
return AuthMethodGitHub
default:
return AuthMethodLocal
}
}
func ParseGitHubAuthConfig() GitHubAuthConfig {
clientID := LookupEnvOrString(GitHubClientIDEnvVar, "")
clientSecret := LookupEnvOrString(GitHubClientSecretEnvVar, "")
clientSecretFile := LookupEnvOrString(GitHubClientSecretFileEnvVar, "")
redirectURL := LookupEnvOrString(GitHubRedirectURLEnvVar, "")
allowedUsersRaw := LookupEnvOrString(GitHubAllowedUsersEnvVar, "")
allowedOrgsRaw := LookupEnvOrString(GitHubAllowedOrgsEnvVar, "")
adminUsersRaw := LookupEnvOrString(GitHubAdminUsersEnvVar, "")
var allowedUsers, allowedOrgs, adminUsers []string
if allowedUsersRaw != "" {
allowedUsers = NormalizeUsernames(strings.Split(allowedUsersRaw, ","))
}
if allowedOrgsRaw != "" {
allowedOrgs = NormalizeUsernames(strings.Split(allowedOrgsRaw, ","))
}
if adminUsersRaw != "" {
adminUsers = NormalizeUsernames(strings.Split(adminUsersRaw, ","))
}
return GitHubAuthConfig{
ClientID: clientID,
ClientSecret: clientSecret,
ClientSecretFile: clientSecretFile,
RedirectURL: redirectURL,
AllowedUsers: allowedUsers,
AllowedOrgs: allowedOrgs,
AdminUsers: adminUsers,
}
}
func ValidateGitHubAuthConfig(config GitHubAuthConfig) error {
if config.ClientID == "" {
return errors.New("github auth config: client_id is required")
}
if config.ClientSecret == "" && config.ClientSecretFile == "" {
return errors.New("github auth config: client_secret is required")
}
if config.ClientSecretFile != "" {
secret, err := os.ReadFile(config.ClientSecretFile)
if err != nil {
return fmt.Errorf("github auth config: reading client_secret_file %q: %w", config.ClientSecretFile, err)
}
if len(strings.TrimSpace(string(secret))) == 0 {
return fmt.Errorf("github auth config: client_secret_file %q is empty", config.ClientSecretFile)
}
}
if config.RedirectURL == "" {
return errors.New("github auth config: redirect_url is required")
}
if len(config.AllowedUsers) == 0 && len(config.AllowedOrgs) == 0 {
return errors.New("github auth config: at least one of allowed_users or allowed_orgs is required")
}
if len(config.AdminUsers) == 0 {
return errors.New("github auth config: admin_users is required")
}
return nil
}
func NormalizeUsernames(users []string) []string {
seen := make(map[string]bool)
result := make([]string, 0, len(users))
for _, user := range users {
normalized := strings.ToLower(strings.TrimSpace(user))
if normalized != "" && !seen[normalized] {
seen[normalized] = true
result = append(result, normalized)
}
}
return result
}

309
util/auth_config_test.go Normal file
View File

@ -0,0 +1,309 @@
package util
import (
"os"
"path/filepath"
"testing"
)
func TestParseAuthMethod_DefaultLocal(t *testing.T) {
t.Setenv("WGUI_AUTH_METHOD", "")
method := ParseAuthMethod()
if method != AuthMethodLocal {
t.Errorf("expected default auth method to be %q, got %q", AuthMethodLocal, method)
}
}
func TestParseAuthMethod_LocalExplicit(t *testing.T) {
t.Setenv("WGUI_AUTH_METHOD", "local")
method := ParseAuthMethod()
if method != AuthMethodLocal {
t.Errorf("expected auth method to be %q, got %q", AuthMethodLocal, method)
}
}
func TestParseAuthMethod_GitHubExplicit(t *testing.T) {
t.Setenv("WGUI_AUTH_METHOD", "github")
method := ParseAuthMethod()
if method != AuthMethodGitHub {
t.Errorf("expected auth method to be %q, got %q", AuthMethodGitHub, method)
}
}
func TestParseAuthMethod_GitHubCaseInsensitive(t *testing.T) {
t.Setenv("WGUI_AUTH_METHOD", "GitHub")
method := ParseAuthMethod()
if method != AuthMethodGitHub {
t.Errorf("expected case-insensitive auth method to be %q, got %q", AuthMethodGitHub, method)
}
}
func TestParseGitHubAuthConfig_NormalizesLists(t *testing.T) {
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", " User1 ,user1, USER2 ")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", " OrgA ,orga ")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", " Admin1 , admin1 ")
config := ParseGitHubAuthConfig()
if config.ClientID != "test-client-id" {
t.Fatalf("expected client_id to be parsed, got %q", config.ClientID)
}
if config.ClientSecret != "test-secret" {
t.Fatalf("expected client_secret to be parsed, got %q", config.ClientSecret)
}
if config.RedirectURL != "http://localhost:8080/callback" {
t.Fatalf("expected redirect_url to be parsed, got %q", config.RedirectURL)
}
expectedUsers := []string{"user1", "user2"}
if len(config.AllowedUsers) != len(expectedUsers) {
t.Fatalf("expected %d allowed users, got %d", len(expectedUsers), len(config.AllowedUsers))
}
for i, user := range expectedUsers {
if config.AllowedUsers[i] != user {
t.Fatalf("expected allowed_users[%d]=%q, got %q", i, user, config.AllowedUsers[i])
}
}
expectedOrgs := []string{"orga"}
if len(config.AllowedOrgs) != len(expectedOrgs) {
t.Fatalf("expected %d allowed orgs, got %d", len(expectedOrgs), len(config.AllowedOrgs))
}
for i, org := range expectedOrgs {
if config.AllowedOrgs[i] != org {
t.Fatalf("expected allowed_orgs[%d]=%q, got %q", i, org, config.AllowedOrgs[i])
}
}
expectedAdmins := []string{"admin1"}
if len(config.AdminUsers) != len(expectedAdmins) {
t.Fatalf("expected %d admin users, got %d", len(expectedAdmins), len(config.AdminUsers))
}
for i, admin := range expectedAdmins {
if config.AdminUsers[i] != admin {
t.Fatalf("expected admin_users[%d]=%q, got %q", i, admin, config.AdminUsers[i])
}
}
}
func TestValidateGitHubAuthConfig_MissingClientID(t *testing.T) {
t.Setenv("WGUI_GITHUB_CLIENT_ID", "")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for missing client_id, got nil")
}
}
func TestValidateGitHubAuthConfig_MissingClientSecret(t *testing.T) {
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{
ClientID: "test-client-id",
}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for missing client_secret and client_secret_file, got nil")
}
}
func TestValidateGitHubAuthConfig_MissingClientSecretFile(t *testing.T) {
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "secret")
os.WriteFile(secretFile, []byte("test-secret"), 0600)
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", secretFile)
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{
ClientID: "test-client-id",
ClientSecretFile: secretFile,
}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for missing client_secret and client_secret_file with no readable content, got nil")
}
}
func TestValidateGitHubAuthConfig_ClientSecretFileReadError(t *testing.T) {
nonExistentFile := "/nonexistent/path/to/secret"
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", nonExistentFile)
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{
ClientID: "test-client-id",
ClientSecretFile: nonExistentFile,
}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for unreadable client_secret_file, got nil")
}
}
func TestValidateGitHubAuthConfig_MissingRedirectURL(t *testing.T) {
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "secret")
os.WriteFile(secretFile, []byte("test-secret"), 0600)
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{
ClientID: "test-client-id",
ClientSecret: "test-secret",
}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for missing redirect_url, got nil")
}
}
func TestValidateGitHubAuthConfig_MissingAllowRules(t *testing.T) {
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "secret")
os.WriteFile(secretFile, []byte("test-secret"), 0600)
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{
ClientID: "test-client-id",
ClientSecret: "test-secret",
RedirectURL: "http://localhost:8080/callback",
}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for missing allow rules (allowed_users or allowed_orgs), got nil")
}
}
func TestValidateGitHubAuthConfig_MissingAdminUsers(t *testing.T) {
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "secret")
os.WriteFile(secretFile, []byte("test-secret"), 0600)
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "user1,user2")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
config := GitHubAuthConfig{
ClientID: "test-client-id",
ClientSecret: "test-secret",
RedirectURL: "http://localhost:8080/callback",
AllowedUsers: []string{"user1", "user2"},
AllowedOrgs: []string{},
}
err := ValidateGitHubAuthConfig(config)
if err == nil {
t.Error("expected error for missing admin_users, got nil")
}
}
func TestValidateGitHubAuthConfig_ValidMinimal(t *testing.T) {
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "secret")
os.WriteFile(secretFile, []byte("test-secret"), 0600)
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "user1,user2")
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "admin1")
config := GitHubAuthConfig{
ClientID: "test-client-id",
ClientSecret: "test-secret",
RedirectURL: "http://localhost:8080/callback",
AllowedUsers: []string{"user1", "user2"},
AdminUsers: []string{"admin1"},
}
err := ValidateGitHubAuthConfig(config)
if err != nil {
t.Errorf("expected no error for valid config, got %v", err)
}
}
func TestNormalizeUsernames_LowercaseAndTrim(t *testing.T) {
input := []string{" User1 ", "USER2", " user3 "}
expected := []string{"user1", "user2", "user3"}
result := NormalizeUsernames(input)
if len(result) != len(expected) {
t.Fatalf("expected length %d, got %d", len(expected), len(result))
}
for i, v := range result {
if v != expected[i] {
t.Errorf("expected[%d]=%q, got[%d]=%q", i, expected[i], i, v)
}
}
}
func TestNormalizeUsernames_Deduplicate(t *testing.T) {
input := []string{"User1", "user1", "USER1", "User2"}
expected := []string{"user1", "user2"}
result := NormalizeUsernames(input)
if len(result) != len(expected) {
t.Fatalf("expected length %d, got %d", len(expected), len(result))
}
for i, v := range result {
if v != expected[i] {
t.Errorf("expected[%d]=%q, got[%d]=%q", i, expected[i], i, v)
}
}
}
func TestNormalizeUsernames_Empty(t *testing.T) {
input := []string{}
expected := []string{}
result := NormalizeUsernames(input)
if len(result) != len(expected) {
t.Fatalf("expected length %d, got %d", len(expected), len(result))
}
}
func TestNormalizeUsernames_BlankEntriesDropped(t *testing.T) {
result := NormalizeUsernames([]string{"", " ", " user1 "})
if len(result) != 1 || result[0] != "user1" {
t.Fatalf("expected blank entries to be dropped, got %#v", result)
}
}

View File

@ -1,6 +1,4 @@
package util
import "sync"
var IPToSubnetRange = map[string]uint16{}
var DBUsersToCRC32 = map[string]uint32{}

View File

@ -9,25 +9,27 @@ import (
// Runtime config
var (
DisableLogin bool
BindAddress string
SmtpHostname string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpNoTLSCheck bool
SmtpEncryption string
SmtpAuthType string
SmtpHelo string
SendgridApiKey string
EmailFrom string
EmailFromName string
SessionSecret [64]byte
SessionMaxDuration int64
WgConfTemplate string
BasePath string
SubnetRanges map[string]([]*net.IPNet)
SubnetRangesOrder []string
DisableLogin bool
BindAddress string
SmtpHostname string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpNoTLSCheck bool
SmtpEncryption string
SmtpAuthType string
SmtpHelo string
SendgridApiKey string
EmailFrom string
EmailFromName string
SessionSecret [64]byte
SessionMaxDuration int64
WgConfTemplate string
BasePath string
SubnetRanges map[string]([]*net.IPNet)
SubnetRangesOrder []string
CurrentAuthMethod AuthMethod
CurrentGitHubConfig GitHubAuthConfig
)
const (

View File

@ -2,8 +2,7 @@ package util
import (
"bufio"
"bytes"
"encoding/gob"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
@ -727,13 +726,28 @@ func ManagePerms(path string) error {
return err
}
// GetDBUserCRC32 builds a stable session hash from security-relevant user fields.
// DisplayName is intentionally excluded so profile presentation changes do not invalidate sessions.
func GetDBUserCRC32(dbuser model.User) uint32 {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
if err := enc.Encode(dbuser); err != nil {
panic("model.User is gob-incompatible, session verification is impossible")
h := crc32.NewIEEE()
writeHashField(h, dbuser.Username)
if dbuser.Admin {
writeHashField(h, "1")
} else {
writeHashField(h, "0")
}
return crc32.ChecksumIEEE(buf.Bytes())
writeHashField(h, dbuser.Password)
writeHashField(h, dbuser.PasswordHash)
writeHashField(h, dbuser.AuthSource)
writeHashField(h, dbuser.AuthSubject)
return h.Sum32()
}
func writeHashField(h io.Writer, value string) {
var length [4]byte
binary.BigEndian.PutUint32(length[:], uint32(len(value)))
_, _ = h.Write(length[:])
_, _ = h.Write([]byte(value))
}
func ConcatMultipleSlices(slices ...[]byte) []byte {

176
util/util_test.go Normal file
View File

@ -0,0 +1,176 @@
package util
import (
"testing"
"github.com/ngoduykhanh/wireguard-ui/model"
)
func TestGetDBUserCRC32_IgnoresDisplayName(t *testing.T) {
user1 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Test User",
}
user2 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Different Name",
}
hash1 := GetDBUserCRC32(user1)
hash2 := GetDBUserCRC32(user2)
if hash1 != hash2 {
t.Errorf("CRC32 should be equal when only DisplayName differs: hash1=%d, hash2=%d", hash1, hash2)
}
}
func TestGetDBUserCRC32_ChangesWhenAdminChanges(t *testing.T) {
user1 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Test User",
}
user2 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: false,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Test User",
}
hash1 := GetDBUserCRC32(user1)
hash2 := GetDBUserCRC32(user2)
if hash1 == hash2 {
t.Errorf("CRC32 should change when Admin field changes: hash1=%d, hash2=%d", hash1, hash2)
}
}
func TestGetDBUserCRC32_ChangesWhenPasswordChanges(t *testing.T) {
user1 := model.User{
Username: "testuser",
Password: "password1",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Test User",
}
user2 := model.User{
Username: "testuser",
Password: "password2",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Test User",
}
hash1 := GetDBUserCRC32(user1)
hash2 := GetDBUserCRC32(user2)
if hash1 == hash2 {
t.Errorf("CRC32 should change when Password field changes: hash1=%d, hash2=%d", hash1, hash2)
}
}
func TestGetDBUserCRC32_ChangesWhenAuthSourceChanges(t *testing.T) {
user1 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "local",
AuthSubject: "",
DisplayName: "Test User",
}
user2 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "github",
AuthSubject: "12345",
DisplayName: "Test User",
}
hash1 := GetDBUserCRC32(user1)
hash2 := GetDBUserCRC32(user2)
if hash1 == hash2 {
t.Errorf("CRC32 should change when AuthSource field changes: hash1=%d, hash2=%d", hash1, hash2)
}
}
func TestGetDBUserCRC32_ChangesWhenAuthSubjectChanges(t *testing.T) {
user1 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "github",
AuthSubject: "12345",
DisplayName: "Test User",
}
user2 := model.User{
Username: "testuser",
Password: "password",
PasswordHash: "hash",
Admin: true,
AuthSource: "github",
AuthSubject: "67890",
DisplayName: "Test User",
}
hash1 := GetDBUserCRC32(user1)
hash2 := GetDBUserCRC32(user2)
if hash1 == hash2 {
t.Errorf("CRC32 should change when AuthSubject field changes: hash1=%d, hash2=%d", hash1, hash2)
}
}
func TestGetDBUserCRC32_DistinguishesEmbeddedNulls(t *testing.T) {
user1 := model.User{
Username: "user\x00name",
Password: "pass",
PasswordHash: "hash",
Admin: false,
AuthSource: "github",
AuthSubject: "subject",
}
user2 := model.User{
Username: "user",
Password: "\x00namepass",
PasswordHash: "hash",
Admin: false,
AuthSource: "github",
AuthSubject: "subject",
}
if GetDBUserCRC32(user1) == GetDBUserCRC32(user2) {
t.Fatal("CRC32 should distinguish values even when fields contain embedded null bytes")
}
}