feat: add GitHub auth configuration and session foundation
This commit is contained in:
parent
9d7948417b
commit
bdbb9250b8
|
|
@ -13,8 +13,6 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/gommon/log"
|
||||
"github.com/rs/xid"
|
||||
|
|
@ -90,43 +88,10 @@ func Login(db store.IStore) echo.HandlerFunc {
|
|||
}
|
||||
|
||||
if userCorrect && passwordCorrect {
|
||||
ageMax := 0
|
||||
if rememberMe {
|
||||
ageMax = 86400 * 7
|
||||
if err := establishAuthenticatedSession(c, dbuser, rememberMe); err != nil {
|
||||
log.Errorf("Failed to establish authenticated session: %v", err)
|
||||
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Failed to establish session"})
|
||||
}
|
||||
|
||||
cookiePath := util.GetCookiePath()
|
||||
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Options = &sessions.Options{
|
||||
Path: cookiePath,
|
||||
MaxAge: ageMax,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
// set session_token
|
||||
tokenUID := xid.New().String()
|
||||
now := time.Now().UTC().Unix()
|
||||
sess.Values["username"] = dbuser.Username
|
||||
sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser)
|
||||
sess.Values["admin"] = dbuser.Admin
|
||||
sess.Values["session_token"] = tokenUID
|
||||
sess.Values["max_age"] = ageMax
|
||||
sess.Values["created_at"] = now
|
||||
sess.Values["updated_at"] = now
|
||||
sess.Save(c.Request(), c.Response())
|
||||
|
||||
// set session_token in cookie
|
||||
cookie := new(http.Cookie)
|
||||
cookie.Name = "session_token"
|
||||
cookie.Path = cookiePath
|
||||
cookie.Value = tokenUID
|
||||
cookie.MaxAge = ageMax
|
||||
cookie.HttpOnly = true
|
||||
cookie.SameSite = http.SameSiteLaxMode
|
||||
c.SetCookie(cookie)
|
||||
|
||||
return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"})
|
||||
}
|
||||
|
||||
|
|
@ -588,6 +553,7 @@ func UpdateClient(db store.IStore) echo.HandlerFunc {
|
|||
}
|
||||
client := *clientData.Client
|
||||
|
||||
allocatedIPs, err := util.GetAllocatedIPs("")
|
||||
check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, _client.AllocatedIPs)
|
||||
if !check {
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)})
|
||||
|
|
|
|||
|
|
@ -8,6 +8,9 @@ import (
|
|||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/ngoduykhanh/wireguard-ui/model"
|
||||
"github.com/ngoduykhanh/wireguard-ui/util"
|
||||
)
|
||||
|
||||
|
|
@ -247,3 +250,47 @@ func clearSession(c echo.Context) {
|
|||
cookie.SameSite = http.SameSiteLaxMode
|
||||
c.SetCookie(cookie)
|
||||
}
|
||||
|
||||
func establishAuthenticatedSession(c echo.Context, user model.User, rememberMe bool) error {
|
||||
sess, err := session.Get("session", c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ageMax := 0
|
||||
if rememberMe {
|
||||
ageMax = 86400 * 7
|
||||
}
|
||||
|
||||
cookiePath := util.GetCookiePath()
|
||||
|
||||
sess.Options = &sessions.Options{
|
||||
Path: cookiePath,
|
||||
MaxAge: ageMax,
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
|
||||
tokenUID := xid.New().String()
|
||||
now := time.Now().UTC().Unix()
|
||||
|
||||
sess.Values["username"] = user.Username
|
||||
sess.Values["user_hash"] = util.GetDBUserCRC32(user)
|
||||
sess.Values["admin"] = user.Admin
|
||||
sess.Values["session_token"] = tokenUID
|
||||
sess.Values["max_age"] = ageMax
|
||||
sess.Values["created_at"] = now
|
||||
sess.Values["updated_at"] = now
|
||||
sess.Save(c.Request(), c.Response())
|
||||
|
||||
cookie := new(http.Cookie)
|
||||
cookie.Name = "session_token"
|
||||
cookie.Path = cookiePath
|
||||
cookie.Value = tokenUID
|
||||
cookie.MaxAge = ageMax
|
||||
cookie.HttpOnly = true
|
||||
cookie.SameSite = http.SameSiteLaxMode
|
||||
c.SetCookie(cookie)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,4 +7,7 @@ type User struct {
|
|||
// PasswordHash takes precedence over Password.
|
||||
PasswordHash string `json:"password_hash"`
|
||||
Admin bool `json:"admin"`
|
||||
AuthSource string `json:"auth_source"`
|
||||
AuthSubject string `json:"auth_subject"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,374 @@
|
|||
package jsondb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ngoduykhanh/wireguard-ui/model"
|
||||
"github.com/sdomino/scribble"
|
||||
)
|
||||
|
||||
func TestGetUserByAuthIdentity_FindsGitHubUser(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
githubUser := model.User{
|
||||
Username: "github-user",
|
||||
Password: "",
|
||||
PasswordHash: "",
|
||||
Admin: false,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "123456",
|
||||
DisplayName: "GitHub User",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(githubUser); err != nil {
|
||||
t.Fatalf("failed to save user: %v", err)
|
||||
}
|
||||
|
||||
found, err := db.GetUserByAuthIdentity("github", "123456")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserByAuthIdentity failed: %v", err)
|
||||
}
|
||||
|
||||
if found.Username != "github-user" {
|
||||
t.Errorf("expected username 'github-user', got %q", found.Username)
|
||||
}
|
||||
if found.AuthSource != "github" {
|
||||
t.Errorf("expected auth_source 'github', got %q", found.AuthSource)
|
||||
}
|
||||
if found.AuthSubject != "123456" {
|
||||
t.Errorf("expected auth_subject '123456', got %q", found.AuthSubject)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByAuthIdentity_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetUserByAuthIdentity("github", "nonexistent")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent auth identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserByAuthIdentity_LocalUserNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
localUser := model.User{
|
||||
Username: "local-user",
|
||||
Password: "secret",
|
||||
PasswordHash: "hash",
|
||||
Admin: false,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Local User",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(localUser); err != nil {
|
||||
t.Fatalf("failed to save user: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetUserByAuthIdentity("local", "")
|
||||
if err == nil {
|
||||
t.Error("expected error for local user queried by empty auth identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUser_RenameKeepsOnlyNewRecord(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
oldUser := model.User{
|
||||
Username: "oldname",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Old Name",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(oldUser); err != nil {
|
||||
t.Fatalf("failed to save user: %v", err)
|
||||
}
|
||||
|
||||
newUser := model.User{
|
||||
Username: "newname",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "New Name",
|
||||
}
|
||||
|
||||
if err := db.ReplaceUser("oldname", newUser); err != nil {
|
||||
t.Fatalf("ReplaceUser failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetUserByName("oldname")
|
||||
if err == nil {
|
||||
t.Error("old username should not exist after rename")
|
||||
}
|
||||
|
||||
_, err = db.GetUserByName("newname")
|
||||
if err != nil {
|
||||
t.Errorf("new username should exist after rename: %v", err)
|
||||
}
|
||||
|
||||
users, err := db.GetUsers()
|
||||
if err != nil {
|
||||
t.Fatalf("GetUsers failed: %v", err)
|
||||
}
|
||||
if len(users) != 1 {
|
||||
t.Errorf("expected exactly 1 user after rename, got %d", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUser_SameNameUpdatesRecord(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
originalUser := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Original",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(originalUser); err != nil {
|
||||
t.Fatalf("failed to save user: %v", err)
|
||||
}
|
||||
|
||||
updatedUser := model.User{
|
||||
Username: "testuser",
|
||||
Password: "newpassword",
|
||||
PasswordHash: "newhash",
|
||||
Admin: false,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Updated",
|
||||
}
|
||||
|
||||
if err := db.ReplaceUser("testuser", updatedUser); err != nil {
|
||||
t.Fatalf("ReplaceUser failed: %v", err)
|
||||
}
|
||||
|
||||
found, err := db.GetUserByName("testuser")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserByName failed: %v", err)
|
||||
}
|
||||
|
||||
if found.Admin != false {
|
||||
t.Errorf("expected admin=false, got %v", found.Admin)
|
||||
}
|
||||
if found.DisplayName != "Updated" {
|
||||
t.Errorf("expected display_name='Updated', got %q", found.DisplayName)
|
||||
}
|
||||
|
||||
users, err := db.GetUsers()
|
||||
if err != nil {
|
||||
t.Fatalf("GetUsers failed: %v", err)
|
||||
}
|
||||
if len(users) != 1 {
|
||||
t.Errorf("expected exactly 1 user after update, got %d", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUser_MigratesLegacyLocalUser(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
legacyUser := model.User{
|
||||
Username: "legacyuser",
|
||||
Password: "oldpassword",
|
||||
PasswordHash: "oldhash",
|
||||
Admin: false,
|
||||
AuthSource: "",
|
||||
AuthSubject: "",
|
||||
DisplayName: "",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(legacyUser); err != nil {
|
||||
t.Fatalf("failed to save legacy user: %v", err)
|
||||
}
|
||||
|
||||
migratedUser := model.User{
|
||||
Username: "legacyuser",
|
||||
Password: "",
|
||||
PasswordHash: "",
|
||||
Admin: false,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "987654",
|
||||
DisplayName: "Migrated User",
|
||||
}
|
||||
|
||||
if err := db.ReplaceUser("legacyuser", migratedUser); err != nil {
|
||||
t.Fatalf("ReplaceUser failed: %v", err)
|
||||
}
|
||||
|
||||
found, err := db.GetUserByAuthIdentity("github", "987654")
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserByAuthIdentity failed: %v", err)
|
||||
}
|
||||
|
||||
if found.Username != "legacyuser" {
|
||||
t.Errorf("expected username 'legacyuser', got %q", found.Username)
|
||||
}
|
||||
if found.AuthSource != "github" {
|
||||
t.Errorf("expected auth_source 'github', got %q", found.AuthSource)
|
||||
}
|
||||
if found.AuthSubject != "987654" {
|
||||
t.Errorf("expected auth_subject '987654', got %q", found.AuthSubject)
|
||||
}
|
||||
if found.Password != "" || found.PasswordHash != "" {
|
||||
t.Errorf("expected empty password fields for migrated user, got password=%q, password_hash=%q", found.Password, found.PasswordHash)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUser_WriteNewFails_KeepsOldRecord(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
existingUser := model.User{
|
||||
Username: "existing",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Existing User",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(existingUser); err != nil {
|
||||
t.Fatalf("failed to save user: %v", err)
|
||||
}
|
||||
|
||||
userDir := filepath.Join(tmpDir, "users")
|
||||
os.Chmod(userDir, 0000)
|
||||
|
||||
newUser := model.User{
|
||||
Username: "newuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: false,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "123456",
|
||||
DisplayName: "New User",
|
||||
}
|
||||
|
||||
err = db.ReplaceUser("existing", newUser)
|
||||
|
||||
os.Chmod(userDir, 0700)
|
||||
|
||||
if err == nil {
|
||||
t.Error("expected error when writing new record fails")
|
||||
}
|
||||
|
||||
_, err = db.GetUserByName("existing")
|
||||
if err != nil {
|
||||
t.Errorf("old record should still exist after failed replace: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetUserByName("newuser")
|
||||
if err == nil {
|
||||
t.Error("new record should not exist after failed replace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceUser_DeleteOldFails_AttemptsRollbackBeforeManualCleanupError(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
db, err := New(tmpDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create db: %v", err)
|
||||
}
|
||||
|
||||
oldUser := model.User{
|
||||
Username: "olduser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Old User",
|
||||
}
|
||||
|
||||
if err := db.SaveUser(oldUser); err != nil {
|
||||
t.Fatalf("failed to save user: %v", err)
|
||||
}
|
||||
|
||||
newUser := model.User{
|
||||
Username: "newuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "New User",
|
||||
}
|
||||
|
||||
origDeleteUserFile := deleteUserFile
|
||||
deleteUserFile = func(conn *scribble.Driver, username string) error {
|
||||
if username == "olduser" {
|
||||
return fmt.Errorf("simulated delete failure for olduser")
|
||||
}
|
||||
if username == "newuser" {
|
||||
return fmt.Errorf("simulated rollback failure for newuser")
|
||||
}
|
||||
return origDeleteUserFile(conn, username)
|
||||
}
|
||||
defer func() { deleteUserFile = origDeleteUserFile }()
|
||||
|
||||
err = db.ReplaceUser("olduser", newUser)
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected error when delete old fails and rollback also fails")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "manual cleanup required") {
|
||||
t.Errorf("error should contain 'manual cleanup required', got: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "rollback failed") {
|
||||
t.Errorf("error should contain 'rollback failed', got: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetUserByName("olduser")
|
||||
if err != nil {
|
||||
t.Errorf("old record should still exist: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetUserByName("newuser")
|
||||
if err != nil {
|
||||
t.Errorf("new record should still exist after failed rollback: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -16,6 +16,10 @@ import (
|
|||
"github.com/ngoduykhanh/wireguard-ui/util"
|
||||
)
|
||||
|
||||
var deleteUserFile = func(conn *scribble.Driver, username string) error {
|
||||
return conn.Delete("users", username)
|
||||
}
|
||||
|
||||
type JsonDB struct {
|
||||
conn *scribble.Driver
|
||||
dbPath string
|
||||
|
|
@ -216,7 +220,78 @@ func (o *JsonDB) SaveUser(user model.User) error {
|
|||
// DeleteUser func to remove user from the database
|
||||
func (o *JsonDB) DeleteUser(username string) error {
|
||||
delete(util.DBUsersToCRC32, username)
|
||||
return o.conn.Delete("users", username)
|
||||
return deleteUserFile(o.conn, username)
|
||||
}
|
||||
|
||||
// GetUserByAuthIdentity func to get a user by auth source and subject
|
||||
func (o *JsonDB) GetUserByAuthIdentity(authSource, authSubject string) (model.User, error) {
|
||||
if authSubject == "" {
|
||||
return model.User{}, fmt.Errorf("auth subject cannot be empty for auth source %q", authSource)
|
||||
}
|
||||
|
||||
users, err := o.GetUsers()
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
for _, user := range users {
|
||||
if user.AuthSource == authSource && user.AuthSubject == authSubject {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
return model.User{}, fmt.Errorf("user not found for auth source %q and subject %q", authSource, authSubject)
|
||||
}
|
||||
|
||||
// ReplaceUser safely replaces or renames a user record.
|
||||
// If oldUsername == user.Username, it performs an update.
|
||||
// If they differ, it writes the new record first, then deletes the old one.
|
||||
// If the old record deletion fails after the new record was written, the new
|
||||
// record remains and an error indicating manual cleanup is returned.
|
||||
func (o *JsonDB) ReplaceUser(oldUsername string, user model.User) error {
|
||||
userPath := path.Join(path.Join(o.dbPath, "users"), user.Username+".json")
|
||||
oldUserPath := path.Join(path.Join(o.dbPath, "users"), oldUsername+".json")
|
||||
|
||||
isRename := oldUsername != user.Username
|
||||
|
||||
if isRename {
|
||||
_, err := os.Stat(oldUserPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("old user record does not exist: %w", err)
|
||||
}
|
||||
|
||||
_, err = os.Stat(userPath)
|
||||
if err == nil {
|
||||
return fmt.Errorf("cannot rename to %q: a record with that name already exists", user.Username)
|
||||
}
|
||||
}
|
||||
|
||||
if err := o.conn.Write("users", user.Username, user); err != nil {
|
||||
return fmt.Errorf("failed to write new user record: %w", err)
|
||||
}
|
||||
|
||||
if err := util.ManagePerms(userPath); err != nil {
|
||||
os.Remove(userPath)
|
||||
return fmt.Errorf("failed to set permissions on new user record: %w", err)
|
||||
}
|
||||
|
||||
util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user)
|
||||
|
||||
if isRename {
|
||||
delete(util.DBUsersToCRC32, oldUsername)
|
||||
|
||||
if err := deleteUserFile(o.conn, oldUsername); err != nil {
|
||||
rollbackErr := deleteUserFile(o.conn, user.Username)
|
||||
if rollbackErr != nil {
|
||||
return fmt.Errorf("failed to delete old user record %q (new record %q written): %v; rollback failed: %v; manual cleanup required: remove %s and ensure %s exists", oldUsername, user.Username, err, rollbackErr, userPath, oldUserPath)
|
||||
}
|
||||
delete(util.DBUsersToCRC32, user.Username)
|
||||
os.RemoveAll(userPath)
|
||||
return fmt.Errorf("failed to delete old user record %q (new record %q was written and rolled back): %v", oldUsername, user.Username, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGlobalSettings func to query global settings from the database
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ type IStore interface {
|
|||
Init() error
|
||||
GetUsers() ([]model.User, error)
|
||||
GetUserByName(username string) (model.User, error)
|
||||
GetUserByAuthIdentity(authSource, authSubject string) (model.User, error)
|
||||
SaveUser(user model.User) error
|
||||
ReplaceUser(oldUsername string, user model.User) error
|
||||
DeleteUser(username string) error
|
||||
GetGlobalSettings() (model.GlobalSetting, error)
|
||||
GetServer() (model.Server, error)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AuthMethod string
|
||||
|
||||
const (
|
||||
AuthMethodLocal AuthMethod = "local"
|
||||
AuthMethodGitHub AuthMethod = "github"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthMethodEnvVar = "WGUI_AUTH_METHOD"
|
||||
GitHubClientIDEnvVar = "WGUI_GITHUB_CLIENT_ID"
|
||||
GitHubClientSecretEnvVar = "WGUI_GITHUB_CLIENT_SECRET"
|
||||
GitHubClientSecretFileEnvVar = "WGUI_GITHUB_CLIENT_SECRET_FILE"
|
||||
GitHubRedirectURLEnvVar = "WGUI_GITHUB_REDIRECT_URL"
|
||||
GitHubAllowedUsersEnvVar = "WGUI_GITHUB_ALLOWED_USERS"
|
||||
GitHubAllowedOrgsEnvVar = "WGUI_GITHUB_ALLOWED_ORGS"
|
||||
GitHubAdminUsersEnvVar = "WGUI_GITHUB_ADMIN_USERS"
|
||||
)
|
||||
|
||||
type GitHubAuthConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
ClientSecretFile string
|
||||
RedirectURL string
|
||||
AllowedUsers []string
|
||||
AllowedOrgs []string
|
||||
AdminUsers []string
|
||||
}
|
||||
|
||||
func ParseAuthMethod() AuthMethod {
|
||||
method := LookupEnvOrString(AuthMethodEnvVar, "local")
|
||||
switch strings.ToLower(method) {
|
||||
case "github":
|
||||
return AuthMethodGitHub
|
||||
default:
|
||||
return AuthMethodLocal
|
||||
}
|
||||
}
|
||||
|
||||
func ParseGitHubAuthConfig() GitHubAuthConfig {
|
||||
clientID := LookupEnvOrString(GitHubClientIDEnvVar, "")
|
||||
clientSecret := LookupEnvOrString(GitHubClientSecretEnvVar, "")
|
||||
clientSecretFile := LookupEnvOrString(GitHubClientSecretFileEnvVar, "")
|
||||
redirectURL := LookupEnvOrString(GitHubRedirectURLEnvVar, "")
|
||||
allowedUsersRaw := LookupEnvOrString(GitHubAllowedUsersEnvVar, "")
|
||||
allowedOrgsRaw := LookupEnvOrString(GitHubAllowedOrgsEnvVar, "")
|
||||
adminUsersRaw := LookupEnvOrString(GitHubAdminUsersEnvVar, "")
|
||||
|
||||
var allowedUsers, allowedOrgs, adminUsers []string
|
||||
if allowedUsersRaw != "" {
|
||||
allowedUsers = NormalizeUsernames(strings.Split(allowedUsersRaw, ","))
|
||||
}
|
||||
if allowedOrgsRaw != "" {
|
||||
allowedOrgs = NormalizeUsernames(strings.Split(allowedOrgsRaw, ","))
|
||||
}
|
||||
if adminUsersRaw != "" {
|
||||
adminUsers = NormalizeUsernames(strings.Split(adminUsersRaw, ","))
|
||||
}
|
||||
|
||||
return GitHubAuthConfig{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
ClientSecretFile: clientSecretFile,
|
||||
RedirectURL: redirectURL,
|
||||
AllowedUsers: allowedUsers,
|
||||
AllowedOrgs: allowedOrgs,
|
||||
AdminUsers: adminUsers,
|
||||
}
|
||||
}
|
||||
|
||||
func ValidateGitHubAuthConfig(config GitHubAuthConfig) error {
|
||||
if config.ClientID == "" {
|
||||
return errors.New("github auth config: client_id is required")
|
||||
}
|
||||
|
||||
if config.ClientSecret == "" && config.ClientSecretFile == "" {
|
||||
return errors.New("github auth config: client_secret is required")
|
||||
}
|
||||
|
||||
if config.ClientSecretFile != "" {
|
||||
secret, err := os.ReadFile(config.ClientSecretFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("github auth config: reading client_secret_file %q: %w", config.ClientSecretFile, err)
|
||||
}
|
||||
if len(strings.TrimSpace(string(secret))) == 0 {
|
||||
return fmt.Errorf("github auth config: client_secret_file %q is empty", config.ClientSecretFile)
|
||||
}
|
||||
}
|
||||
|
||||
if config.RedirectURL == "" {
|
||||
return errors.New("github auth config: redirect_url is required")
|
||||
}
|
||||
|
||||
if len(config.AllowedUsers) == 0 && len(config.AllowedOrgs) == 0 {
|
||||
return errors.New("github auth config: at least one of allowed_users or allowed_orgs is required")
|
||||
}
|
||||
|
||||
if len(config.AdminUsers) == 0 {
|
||||
return errors.New("github auth config: admin_users is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeUsernames(users []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
result := make([]string, 0, len(users))
|
||||
for _, user := range users {
|
||||
normalized := strings.ToLower(strings.TrimSpace(user))
|
||||
if normalized != "" && !seen[normalized] {
|
||||
seen[normalized] = true
|
||||
result = append(result, normalized)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
@ -0,0 +1,309 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseAuthMethod_DefaultLocal(t *testing.T) {
|
||||
t.Setenv("WGUI_AUTH_METHOD", "")
|
||||
method := ParseAuthMethod()
|
||||
if method != AuthMethodLocal {
|
||||
t.Errorf("expected default auth method to be %q, got %q", AuthMethodLocal, method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuthMethod_LocalExplicit(t *testing.T) {
|
||||
t.Setenv("WGUI_AUTH_METHOD", "local")
|
||||
method := ParseAuthMethod()
|
||||
if method != AuthMethodLocal {
|
||||
t.Errorf("expected auth method to be %q, got %q", AuthMethodLocal, method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuthMethod_GitHubExplicit(t *testing.T) {
|
||||
t.Setenv("WGUI_AUTH_METHOD", "github")
|
||||
method := ParseAuthMethod()
|
||||
if method != AuthMethodGitHub {
|
||||
t.Errorf("expected auth method to be %q, got %q", AuthMethodGitHub, method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuthMethod_GitHubCaseInsensitive(t *testing.T) {
|
||||
t.Setenv("WGUI_AUTH_METHOD", "GitHub")
|
||||
method := ParseAuthMethod()
|
||||
if method != AuthMethodGitHub {
|
||||
t.Errorf("expected case-insensitive auth method to be %q, got %q", AuthMethodGitHub, method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGitHubAuthConfig_NormalizesLists(t *testing.T) {
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", " User1 ,user1, USER2 ")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", " OrgA ,orga ")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", " Admin1 , admin1 ")
|
||||
|
||||
config := ParseGitHubAuthConfig()
|
||||
|
||||
if config.ClientID != "test-client-id" {
|
||||
t.Fatalf("expected client_id to be parsed, got %q", config.ClientID)
|
||||
}
|
||||
if config.ClientSecret != "test-secret" {
|
||||
t.Fatalf("expected client_secret to be parsed, got %q", config.ClientSecret)
|
||||
}
|
||||
if config.RedirectURL != "http://localhost:8080/callback" {
|
||||
t.Fatalf("expected redirect_url to be parsed, got %q", config.RedirectURL)
|
||||
}
|
||||
|
||||
expectedUsers := []string{"user1", "user2"}
|
||||
if len(config.AllowedUsers) != len(expectedUsers) {
|
||||
t.Fatalf("expected %d allowed users, got %d", len(expectedUsers), len(config.AllowedUsers))
|
||||
}
|
||||
for i, user := range expectedUsers {
|
||||
if config.AllowedUsers[i] != user {
|
||||
t.Fatalf("expected allowed_users[%d]=%q, got %q", i, user, config.AllowedUsers[i])
|
||||
}
|
||||
}
|
||||
|
||||
expectedOrgs := []string{"orga"}
|
||||
if len(config.AllowedOrgs) != len(expectedOrgs) {
|
||||
t.Fatalf("expected %d allowed orgs, got %d", len(expectedOrgs), len(config.AllowedOrgs))
|
||||
}
|
||||
for i, org := range expectedOrgs {
|
||||
if config.AllowedOrgs[i] != org {
|
||||
t.Fatalf("expected allowed_orgs[%d]=%q, got %q", i, org, config.AllowedOrgs[i])
|
||||
}
|
||||
}
|
||||
|
||||
expectedAdmins := []string{"admin1"}
|
||||
if len(config.AdminUsers) != len(expectedAdmins) {
|
||||
t.Fatalf("expected %d admin users, got %d", len(expectedAdmins), len(config.AdminUsers))
|
||||
}
|
||||
for i, admin := range expectedAdmins {
|
||||
if config.AdminUsers[i] != admin {
|
||||
t.Fatalf("expected admin_users[%d]=%q, got %q", i, admin, config.AdminUsers[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_MissingClientID(t *testing.T) {
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing client_id, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_MissingClientSecret(t *testing.T) {
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing client_secret and client_secret_file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_MissingClientSecretFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "secret")
|
||||
os.WriteFile(secretFile, []byte("test-secret"), 0600)
|
||||
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", secretFile)
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecretFile: secretFile,
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing client_secret and client_secret_file with no readable content, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_ClientSecretFileReadError(t *testing.T) {
|
||||
nonExistentFile := "/nonexistent/path/to/secret"
|
||||
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", nonExistentFile)
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecretFile: nonExistentFile,
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for unreadable client_secret_file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_MissingRedirectURL(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "secret")
|
||||
os.WriteFile(secretFile, []byte("test-secret"), 0600)
|
||||
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-secret",
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing redirect_url, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_MissingAllowRules(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "secret")
|
||||
os.WriteFile(secretFile, []byte("test-secret"), 0600)
|
||||
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-secret",
|
||||
RedirectURL: "http://localhost:8080/callback",
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing allow rules (allowed_users or allowed_orgs), got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_MissingAdminUsers(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "secret")
|
||||
os.WriteFile(secretFile, []byte("test-secret"), 0600)
|
||||
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "user1,user2")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-secret",
|
||||
RedirectURL: "http://localhost:8080/callback",
|
||||
AllowedUsers: []string{"user1", "user2"},
|
||||
AllowedOrgs: []string{},
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing admin_users, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGitHubAuthConfig_ValidMinimal(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "secret")
|
||||
os.WriteFile(secretFile, []byte("test-secret"), 0600)
|
||||
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id")
|
||||
t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret")
|
||||
t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "user1,user2")
|
||||
t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "")
|
||||
t.Setenv("WGUI_GITHUB_ADMIN_USERS", "admin1")
|
||||
|
||||
config := GitHubAuthConfig{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-secret",
|
||||
RedirectURL: "http://localhost:8080/callback",
|
||||
AllowedUsers: []string{"user1", "user2"},
|
||||
AdminUsers: []string{"admin1"},
|
||||
}
|
||||
err := ValidateGitHubAuthConfig(config)
|
||||
if err != nil {
|
||||
t.Errorf("expected no error for valid config, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUsernames_LowercaseAndTrim(t *testing.T) {
|
||||
input := []string{" User1 ", "USER2", " user3 "}
|
||||
expected := []string{"user1", "user2", "user3"}
|
||||
result := NormalizeUsernames(input)
|
||||
if len(result) != len(expected) {
|
||||
t.Fatalf("expected length %d, got %d", len(expected), len(result))
|
||||
}
|
||||
for i, v := range result {
|
||||
if v != expected[i] {
|
||||
t.Errorf("expected[%d]=%q, got[%d]=%q", i, expected[i], i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUsernames_Deduplicate(t *testing.T) {
|
||||
input := []string{"User1", "user1", "USER1", "User2"}
|
||||
expected := []string{"user1", "user2"}
|
||||
result := NormalizeUsernames(input)
|
||||
if len(result) != len(expected) {
|
||||
t.Fatalf("expected length %d, got %d", len(expected), len(result))
|
||||
}
|
||||
for i, v := range result {
|
||||
if v != expected[i] {
|
||||
t.Errorf("expected[%d]=%q, got[%d]=%q", i, expected[i], i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUsernames_Empty(t *testing.T) {
|
||||
input := []string{}
|
||||
expected := []string{}
|
||||
result := NormalizeUsernames(input)
|
||||
if len(result) != len(expected) {
|
||||
t.Fatalf("expected length %d, got %d", len(expected), len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeUsernames_BlankEntriesDropped(t *testing.T) {
|
||||
result := NormalizeUsernames([]string{"", " ", " user1 "})
|
||||
if len(result) != 1 || result[0] != "user1" {
|
||||
t.Fatalf("expected blank entries to be dropped, got %#v", result)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,6 +1,4 @@
|
|||
package util
|
||||
|
||||
import "sync"
|
||||
|
||||
var IPToSubnetRange = map[string]uint16{}
|
||||
var DBUsersToCRC32 = map[string]uint32{}
|
||||
|
|
|
|||
|
|
@ -9,25 +9,27 @@ import (
|
|||
|
||||
// Runtime config
|
||||
var (
|
||||
DisableLogin bool
|
||||
BindAddress string
|
||||
SmtpHostname string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpNoTLSCheck bool
|
||||
SmtpEncryption string
|
||||
SmtpAuthType string
|
||||
SmtpHelo string
|
||||
SendgridApiKey string
|
||||
EmailFrom string
|
||||
EmailFromName string
|
||||
SessionSecret [64]byte
|
||||
SessionMaxDuration int64
|
||||
WgConfTemplate string
|
||||
BasePath string
|
||||
SubnetRanges map[string]([]*net.IPNet)
|
||||
SubnetRangesOrder []string
|
||||
DisableLogin bool
|
||||
BindAddress string
|
||||
SmtpHostname string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpNoTLSCheck bool
|
||||
SmtpEncryption string
|
||||
SmtpAuthType string
|
||||
SmtpHelo string
|
||||
SendgridApiKey string
|
||||
EmailFrom string
|
||||
EmailFromName string
|
||||
SessionSecret [64]byte
|
||||
SessionMaxDuration int64
|
||||
WgConfTemplate string
|
||||
BasePath string
|
||||
SubnetRanges map[string]([]*net.IPNet)
|
||||
SubnetRangesOrder []string
|
||||
CurrentAuthMethod AuthMethod
|
||||
CurrentGitHubConfig GitHubAuthConfig
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
|
|||
28
util/util.go
28
util/util.go
|
|
@ -2,8 +2,7 @@ package util
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -727,13 +726,28 @@ func ManagePerms(path string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// GetDBUserCRC32 builds a stable session hash from security-relevant user fields.
|
||||
// DisplayName is intentionally excluded so profile presentation changes do not invalidate sessions.
|
||||
func GetDBUserCRC32(dbuser model.User) uint32 {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
if err := enc.Encode(dbuser); err != nil {
|
||||
panic("model.User is gob-incompatible, session verification is impossible")
|
||||
h := crc32.NewIEEE()
|
||||
writeHashField(h, dbuser.Username)
|
||||
if dbuser.Admin {
|
||||
writeHashField(h, "1")
|
||||
} else {
|
||||
writeHashField(h, "0")
|
||||
}
|
||||
return crc32.ChecksumIEEE(buf.Bytes())
|
||||
writeHashField(h, dbuser.Password)
|
||||
writeHashField(h, dbuser.PasswordHash)
|
||||
writeHashField(h, dbuser.AuthSource)
|
||||
writeHashField(h, dbuser.AuthSubject)
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
func writeHashField(h io.Writer, value string) {
|
||||
var length [4]byte
|
||||
binary.BigEndian.PutUint32(length[:], uint32(len(value)))
|
||||
_, _ = h.Write(length[:])
|
||||
_, _ = h.Write([]byte(value))
|
||||
}
|
||||
|
||||
func ConcatMultipleSlices(slices ...[]byte) []byte {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,176 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ngoduykhanh/wireguard-ui/model"
|
||||
)
|
||||
|
||||
func TestGetDBUserCRC32_IgnoresDisplayName(t *testing.T) {
|
||||
user1 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
user2 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Different Name",
|
||||
}
|
||||
|
||||
hash1 := GetDBUserCRC32(user1)
|
||||
hash2 := GetDBUserCRC32(user2)
|
||||
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("CRC32 should be equal when only DisplayName differs: hash1=%d, hash2=%d", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBUserCRC32_ChangesWhenAdminChanges(t *testing.T) {
|
||||
user1 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
user2 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: false,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
hash1 := GetDBUserCRC32(user1)
|
||||
hash2 := GetDBUserCRC32(user2)
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("CRC32 should change when Admin field changes: hash1=%d, hash2=%d", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBUserCRC32_ChangesWhenPasswordChanges(t *testing.T) {
|
||||
user1 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password1",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
user2 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password2",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
hash1 := GetDBUserCRC32(user1)
|
||||
hash2 := GetDBUserCRC32(user2)
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("CRC32 should change when Password field changes: hash1=%d, hash2=%d", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBUserCRC32_ChangesWhenAuthSourceChanges(t *testing.T) {
|
||||
user1 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
user2 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "12345",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
hash1 := GetDBUserCRC32(user1)
|
||||
hash2 := GetDBUserCRC32(user2)
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("CRC32 should change when AuthSource field changes: hash1=%d, hash2=%d", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBUserCRC32_ChangesWhenAuthSubjectChanges(t *testing.T) {
|
||||
user1 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "12345",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
user2 := model.User{
|
||||
Username: "testuser",
|
||||
Password: "password",
|
||||
PasswordHash: "hash",
|
||||
Admin: true,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "67890",
|
||||
DisplayName: "Test User",
|
||||
}
|
||||
|
||||
hash1 := GetDBUserCRC32(user1)
|
||||
hash2 := GetDBUserCRC32(user2)
|
||||
|
||||
if hash1 == hash2 {
|
||||
t.Errorf("CRC32 should change when AuthSubject field changes: hash1=%d, hash2=%d", hash1, hash2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDBUserCRC32_DistinguishesEmbeddedNulls(t *testing.T) {
|
||||
user1 := model.User{
|
||||
Username: "user\x00name",
|
||||
Password: "pass",
|
||||
PasswordHash: "hash",
|
||||
Admin: false,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "subject",
|
||||
}
|
||||
|
||||
user2 := model.User{
|
||||
Username: "user",
|
||||
Password: "\x00namepass",
|
||||
PasswordHash: "hash",
|
||||
Admin: false,
|
||||
AuthSource: "github",
|
||||
AuthSubject: "subject",
|
||||
}
|
||||
|
||||
if GetDBUserCRC32(user1) == GetDBUserCRC32(user2) {
|
||||
t.Fatal("CRC32 should distinguish values even when fields contain embedded null bytes")
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue