From bdbb9250b8032f88815aeebc2850036f08898573 Mon Sep 17 00:00:00 2001 From: devcxl <64475363+devcxl@users.noreply.github.com> Date: Sat, 4 Apr 2026 07:34:17 +0800 Subject: [PATCH] feat: add GitHub auth configuration and session foundation --- handler/routes.go | 42 +--- handler/session.go | 47 ++++ model/user.go | 3 + store/jsondb/auth_identity_test.go | 374 +++++++++++++++++++++++++++++ store/jsondb/jsondb.go | 77 +++++- store/store.go | 2 + util/auth_config.go | 124 ++++++++++ util/auth_config_test.go | 309 ++++++++++++++++++++++++ util/cache.go | 2 - util/config.go | 40 +-- util/util.go | 28 ++- util/util_test.go | 176 ++++++++++++++ 12 files changed, 1157 insertions(+), 67 deletions(-) create mode 100644 store/jsondb/auth_identity_test.go create mode 100644 util/auth_config.go create mode 100644 util/auth_config_test.go create mode 100644 util/util_test.go diff --git a/handler/routes.go b/handler/routes.go index 9a2bd9a..1e4109f 100644 --- a/handler/routes.go +++ b/handler/routes.go @@ -13,8 +13,6 @@ import ( "strings" "time" - "github.com/gorilla/sessions" - "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" "github.com/labstack/gommon/log" "github.com/rs/xid" @@ -90,43 +88,10 @@ func Login(db store.IStore) echo.HandlerFunc { } if userCorrect && passwordCorrect { - ageMax := 0 - if rememberMe { - ageMax = 86400 * 7 + if err := establishAuthenticatedSession(c, dbuser, rememberMe); err != nil { + log.Errorf("Failed to establish authenticated session: %v", err) + return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Failed to establish session"}) } - - cookiePath := util.GetCookiePath() - - sess, _ := session.Get("session", c) - sess.Options = &sessions.Options{ - Path: cookiePath, - MaxAge: ageMax, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - } - - // set session_token - tokenUID := xid.New().String() - now := time.Now().UTC().Unix() - sess.Values["username"] = dbuser.Username - sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser) - sess.Values["admin"] = dbuser.Admin - sess.Values["session_token"] = tokenUID - sess.Values["max_age"] = ageMax - sess.Values["created_at"] = now - sess.Values["updated_at"] = now - sess.Save(c.Request(), c.Response()) - - // set session_token in cookie - cookie := new(http.Cookie) - cookie.Name = "session_token" - cookie.Path = cookiePath - cookie.Value = tokenUID - cookie.MaxAge = ageMax - cookie.HttpOnly = true - cookie.SameSite = http.SameSiteLaxMode - c.SetCookie(cookie) - return c.JSON(http.StatusOK, jsonHTTPResponse{true, "Logged in successfully"}) } @@ -588,6 +553,7 @@ func UpdateClient(db store.IStore) echo.HandlerFunc { } client := *clientData.Client + allocatedIPs, err := util.GetAllocatedIPs("") check, err := util.ValidateIPAllocation(server.Interface.Addresses, allocatedIPs, _client.AllocatedIPs) if !check { return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, fmt.Sprintf("%s", err)}) diff --git a/handler/session.go b/handler/session.go index b660d9c..bdb6e5e 100644 --- a/handler/session.go +++ b/handler/session.go @@ -8,6 +8,9 @@ import ( "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" + "github.com/rs/xid" + + "github.com/ngoduykhanh/wireguard-ui/model" "github.com/ngoduykhanh/wireguard-ui/util" ) @@ -247,3 +250,47 @@ func clearSession(c echo.Context) { cookie.SameSite = http.SameSiteLaxMode c.SetCookie(cookie) } + +func establishAuthenticatedSession(c echo.Context, user model.User, rememberMe bool) error { + sess, err := session.Get("session", c) + if err != nil { + return err + } + + ageMax := 0 + if rememberMe { + ageMax = 86400 * 7 + } + + cookiePath := util.GetCookiePath() + + sess.Options = &sessions.Options{ + Path: cookiePath, + MaxAge: ageMax, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + } + + tokenUID := xid.New().String() + now := time.Now().UTC().Unix() + + sess.Values["username"] = user.Username + sess.Values["user_hash"] = util.GetDBUserCRC32(user) + sess.Values["admin"] = user.Admin + sess.Values["session_token"] = tokenUID + sess.Values["max_age"] = ageMax + sess.Values["created_at"] = now + sess.Values["updated_at"] = now + sess.Save(c.Request(), c.Response()) + + cookie := new(http.Cookie) + cookie.Name = "session_token" + cookie.Path = cookiePath + cookie.Value = tokenUID + cookie.MaxAge = ageMax + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteLaxMode + c.SetCookie(cookie) + + return nil +} diff --git a/model/user.go b/model/user.go index 71f4d13..eeb31cb 100644 --- a/model/user.go +++ b/model/user.go @@ -7,4 +7,7 @@ type User struct { // PasswordHash takes precedence over Password. PasswordHash string `json:"password_hash"` Admin bool `json:"admin"` + AuthSource string `json:"auth_source"` + AuthSubject string `json:"auth_subject"` + DisplayName string `json:"display_name"` } diff --git a/store/jsondb/auth_identity_test.go b/store/jsondb/auth_identity_test.go new file mode 100644 index 0000000..19ffa1e --- /dev/null +++ b/store/jsondb/auth_identity_test.go @@ -0,0 +1,374 @@ +package jsondb + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ngoduykhanh/wireguard-ui/model" + "github.com/sdomino/scribble" +) + +func TestGetUserByAuthIdentity_FindsGitHubUser(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + githubUser := model.User{ + Username: "github-user", + Password: "", + PasswordHash: "", + Admin: false, + AuthSource: "github", + AuthSubject: "123456", + DisplayName: "GitHub User", + } + + if err := db.SaveUser(githubUser); err != nil { + t.Fatalf("failed to save user: %v", err) + } + + found, err := db.GetUserByAuthIdentity("github", "123456") + if err != nil { + t.Fatalf("GetUserByAuthIdentity failed: %v", err) + } + + if found.Username != "github-user" { + t.Errorf("expected username 'github-user', got %q", found.Username) + } + if found.AuthSource != "github" { + t.Errorf("expected auth_source 'github', got %q", found.AuthSource) + } + if found.AuthSubject != "123456" { + t.Errorf("expected auth_subject '123456', got %q", found.AuthSubject) + } +} + +func TestGetUserByAuthIdentity_NotFound(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + _, err = db.GetUserByAuthIdentity("github", "nonexistent") + if err == nil { + t.Error("expected error for nonexistent auth identity") + } +} + +func TestGetUserByAuthIdentity_LocalUserNotFound(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + localUser := model.User{ + Username: "local-user", + Password: "secret", + PasswordHash: "hash", + Admin: false, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Local User", + } + + if err := db.SaveUser(localUser); err != nil { + t.Fatalf("failed to save user: %v", err) + } + + _, err = db.GetUserByAuthIdentity("local", "") + if err == nil { + t.Error("expected error for local user queried by empty auth identity") + } +} + +func TestReplaceUser_RenameKeepsOnlyNewRecord(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + oldUser := model.User{ + Username: "oldname", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Old Name", + } + + if err := db.SaveUser(oldUser); err != nil { + t.Fatalf("failed to save user: %v", err) + } + + newUser := model.User{ + Username: "newname", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "New Name", + } + + if err := db.ReplaceUser("oldname", newUser); err != nil { + t.Fatalf("ReplaceUser failed: %v", err) + } + + _, err = db.GetUserByName("oldname") + if err == nil { + t.Error("old username should not exist after rename") + } + + _, err = db.GetUserByName("newname") + if err != nil { + t.Errorf("new username should exist after rename: %v", err) + } + + users, err := db.GetUsers() + if err != nil { + t.Fatalf("GetUsers failed: %v", err) + } + if len(users) != 1 { + t.Errorf("expected exactly 1 user after rename, got %d", len(users)) + } +} + +func TestReplaceUser_SameNameUpdatesRecord(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + originalUser := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Original", + } + + if err := db.SaveUser(originalUser); err != nil { + t.Fatalf("failed to save user: %v", err) + } + + updatedUser := model.User{ + Username: "testuser", + Password: "newpassword", + PasswordHash: "newhash", + Admin: false, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Updated", + } + + if err := db.ReplaceUser("testuser", updatedUser); err != nil { + t.Fatalf("ReplaceUser failed: %v", err) + } + + found, err := db.GetUserByName("testuser") + if err != nil { + t.Fatalf("GetUserByName failed: %v", err) + } + + if found.Admin != false { + t.Errorf("expected admin=false, got %v", found.Admin) + } + if found.DisplayName != "Updated" { + t.Errorf("expected display_name='Updated', got %q", found.DisplayName) + } + + users, err := db.GetUsers() + if err != nil { + t.Fatalf("GetUsers failed: %v", err) + } + if len(users) != 1 { + t.Errorf("expected exactly 1 user after update, got %d", len(users)) + } +} + +func TestReplaceUser_MigratesLegacyLocalUser(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + legacyUser := model.User{ + Username: "legacyuser", + Password: "oldpassword", + PasswordHash: "oldhash", + Admin: false, + AuthSource: "", + AuthSubject: "", + DisplayName: "", + } + + if err := db.SaveUser(legacyUser); err != nil { + t.Fatalf("failed to save legacy user: %v", err) + } + + migratedUser := model.User{ + Username: "legacyuser", + Password: "", + PasswordHash: "", + Admin: false, + AuthSource: "github", + AuthSubject: "987654", + DisplayName: "Migrated User", + } + + if err := db.ReplaceUser("legacyuser", migratedUser); err != nil { + t.Fatalf("ReplaceUser failed: %v", err) + } + + found, err := db.GetUserByAuthIdentity("github", "987654") + if err != nil { + t.Fatalf("GetUserByAuthIdentity failed: %v", err) + } + + if found.Username != "legacyuser" { + t.Errorf("expected username 'legacyuser', got %q", found.Username) + } + if found.AuthSource != "github" { + t.Errorf("expected auth_source 'github', got %q", found.AuthSource) + } + if found.AuthSubject != "987654" { + t.Errorf("expected auth_subject '987654', got %q", found.AuthSubject) + } + if found.Password != "" || found.PasswordHash != "" { + t.Errorf("expected empty password fields for migrated user, got password=%q, password_hash=%q", found.Password, found.PasswordHash) + } +} + +func TestReplaceUser_WriteNewFails_KeepsOldRecord(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + existingUser := model.User{ + Username: "existing", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Existing User", + } + + if err := db.SaveUser(existingUser); err != nil { + t.Fatalf("failed to save user: %v", err) + } + + userDir := filepath.Join(tmpDir, "users") + os.Chmod(userDir, 0000) + + newUser := model.User{ + Username: "newuser", + Password: "password", + PasswordHash: "hash", + Admin: false, + AuthSource: "github", + AuthSubject: "123456", + DisplayName: "New User", + } + + err = db.ReplaceUser("existing", newUser) + + os.Chmod(userDir, 0700) + + if err == nil { + t.Error("expected error when writing new record fails") + } + + _, err = db.GetUserByName("existing") + if err != nil { + t.Errorf("old record should still exist after failed replace: %v", err) + } + + _, err = db.GetUserByName("newuser") + if err == nil { + t.Error("new record should not exist after failed replace") + } +} + +func TestReplaceUser_DeleteOldFails_AttemptsRollbackBeforeManualCleanupError(t *testing.T) { + tmpDir := t.TempDir() + db, err := New(tmpDir) + if err != nil { + t.Fatalf("failed to create db: %v", err) + } + + oldUser := model.User{ + Username: "olduser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Old User", + } + + if err := db.SaveUser(oldUser); err != nil { + t.Fatalf("failed to save user: %v", err) + } + + newUser := model.User{ + Username: "newuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "New User", + } + + origDeleteUserFile := deleteUserFile + deleteUserFile = func(conn *scribble.Driver, username string) error { + if username == "olduser" { + return fmt.Errorf("simulated delete failure for olduser") + } + if username == "newuser" { + return fmt.Errorf("simulated rollback failure for newuser") + } + return origDeleteUserFile(conn, username) + } + defer func() { deleteUserFile = origDeleteUserFile }() + + err = db.ReplaceUser("olduser", newUser) + + if err == nil { + t.Fatal("expected error when delete old fails and rollback also fails") + } + + if !strings.Contains(err.Error(), "manual cleanup required") { + t.Errorf("error should contain 'manual cleanup required', got: %v", err) + } + + if !strings.Contains(err.Error(), "rollback failed") { + t.Errorf("error should contain 'rollback failed', got: %v", err) + } + + _, err = db.GetUserByName("olduser") + if err != nil { + t.Errorf("old record should still exist: %v", err) + } + + _, err = db.GetUserByName("newuser") + if err != nil { + t.Errorf("new record should still exist after failed rollback: %v", err) + } +} diff --git a/store/jsondb/jsondb.go b/store/jsondb/jsondb.go index f4bab44..768790a 100644 --- a/store/jsondb/jsondb.go +++ b/store/jsondb/jsondb.go @@ -16,6 +16,10 @@ import ( "github.com/ngoduykhanh/wireguard-ui/util" ) +var deleteUserFile = func(conn *scribble.Driver, username string) error { + return conn.Delete("users", username) +} + type JsonDB struct { conn *scribble.Driver dbPath string @@ -216,7 +220,78 @@ func (o *JsonDB) SaveUser(user model.User) error { // DeleteUser func to remove user from the database func (o *JsonDB) DeleteUser(username string) error { delete(util.DBUsersToCRC32, username) - return o.conn.Delete("users", username) + return deleteUserFile(o.conn, username) +} + +// GetUserByAuthIdentity func to get a user by auth source and subject +func (o *JsonDB) GetUserByAuthIdentity(authSource, authSubject string) (model.User, error) { + if authSubject == "" { + return model.User{}, fmt.Errorf("auth subject cannot be empty for auth source %q", authSource) + } + + users, err := o.GetUsers() + if err != nil { + return model.User{}, err + } + + for _, user := range users { + if user.AuthSource == authSource && user.AuthSubject == authSubject { + return user, nil + } + } + + return model.User{}, fmt.Errorf("user not found for auth source %q and subject %q", authSource, authSubject) +} + +// ReplaceUser safely replaces or renames a user record. +// If oldUsername == user.Username, it performs an update. +// If they differ, it writes the new record first, then deletes the old one. +// If the old record deletion fails after the new record was written, the new +// record remains and an error indicating manual cleanup is returned. +func (o *JsonDB) ReplaceUser(oldUsername string, user model.User) error { + userPath := path.Join(path.Join(o.dbPath, "users"), user.Username+".json") + oldUserPath := path.Join(path.Join(o.dbPath, "users"), oldUsername+".json") + + isRename := oldUsername != user.Username + + if isRename { + _, err := os.Stat(oldUserPath) + if err != nil { + return fmt.Errorf("old user record does not exist: %w", err) + } + + _, err = os.Stat(userPath) + if err == nil { + return fmt.Errorf("cannot rename to %q: a record with that name already exists", user.Username) + } + } + + if err := o.conn.Write("users", user.Username, user); err != nil { + return fmt.Errorf("failed to write new user record: %w", err) + } + + if err := util.ManagePerms(userPath); err != nil { + os.Remove(userPath) + return fmt.Errorf("failed to set permissions on new user record: %w", err) + } + + util.DBUsersToCRC32[user.Username] = util.GetDBUserCRC32(user) + + if isRename { + delete(util.DBUsersToCRC32, oldUsername) + + if err := deleteUserFile(o.conn, oldUsername); err != nil { + rollbackErr := deleteUserFile(o.conn, user.Username) + if rollbackErr != nil { + return fmt.Errorf("failed to delete old user record %q (new record %q written): %v; rollback failed: %v; manual cleanup required: remove %s and ensure %s exists", oldUsername, user.Username, err, rollbackErr, userPath, oldUserPath) + } + delete(util.DBUsersToCRC32, user.Username) + os.RemoveAll(userPath) + return fmt.Errorf("failed to delete old user record %q (new record %q was written and rolled back): %v", oldUsername, user.Username, err) + } + } + + return nil } // GetGlobalSettings func to query global settings from the database diff --git a/store/store.go b/store/store.go index ef6d723..4bc76a2 100644 --- a/store/store.go +++ b/store/store.go @@ -8,7 +8,9 @@ type IStore interface { Init() error GetUsers() ([]model.User, error) GetUserByName(username string) (model.User, error) + GetUserByAuthIdentity(authSource, authSubject string) (model.User, error) SaveUser(user model.User) error + ReplaceUser(oldUsername string, user model.User) error DeleteUser(username string) error GetGlobalSettings() (model.GlobalSetting, error) GetServer() (model.Server, error) diff --git a/util/auth_config.go b/util/auth_config.go new file mode 100644 index 0000000..1e4b89f --- /dev/null +++ b/util/auth_config.go @@ -0,0 +1,124 @@ +package util + +import ( + "errors" + "fmt" + "os" + "strings" +) + +type AuthMethod string + +const ( + AuthMethodLocal AuthMethod = "local" + AuthMethodGitHub AuthMethod = "github" +) + +const ( + AuthMethodEnvVar = "WGUI_AUTH_METHOD" + GitHubClientIDEnvVar = "WGUI_GITHUB_CLIENT_ID" + GitHubClientSecretEnvVar = "WGUI_GITHUB_CLIENT_SECRET" + GitHubClientSecretFileEnvVar = "WGUI_GITHUB_CLIENT_SECRET_FILE" + GitHubRedirectURLEnvVar = "WGUI_GITHUB_REDIRECT_URL" + GitHubAllowedUsersEnvVar = "WGUI_GITHUB_ALLOWED_USERS" + GitHubAllowedOrgsEnvVar = "WGUI_GITHUB_ALLOWED_ORGS" + GitHubAdminUsersEnvVar = "WGUI_GITHUB_ADMIN_USERS" +) + +type GitHubAuthConfig struct { + ClientID string + ClientSecret string + ClientSecretFile string + RedirectURL string + AllowedUsers []string + AllowedOrgs []string + AdminUsers []string +} + +func ParseAuthMethod() AuthMethod { + method := LookupEnvOrString(AuthMethodEnvVar, "local") + switch strings.ToLower(method) { + case "github": + return AuthMethodGitHub + default: + return AuthMethodLocal + } +} + +func ParseGitHubAuthConfig() GitHubAuthConfig { + clientID := LookupEnvOrString(GitHubClientIDEnvVar, "") + clientSecret := LookupEnvOrString(GitHubClientSecretEnvVar, "") + clientSecretFile := LookupEnvOrString(GitHubClientSecretFileEnvVar, "") + redirectURL := LookupEnvOrString(GitHubRedirectURLEnvVar, "") + allowedUsersRaw := LookupEnvOrString(GitHubAllowedUsersEnvVar, "") + allowedOrgsRaw := LookupEnvOrString(GitHubAllowedOrgsEnvVar, "") + adminUsersRaw := LookupEnvOrString(GitHubAdminUsersEnvVar, "") + + var allowedUsers, allowedOrgs, adminUsers []string + if allowedUsersRaw != "" { + allowedUsers = NormalizeUsernames(strings.Split(allowedUsersRaw, ",")) + } + if allowedOrgsRaw != "" { + allowedOrgs = NormalizeUsernames(strings.Split(allowedOrgsRaw, ",")) + } + if adminUsersRaw != "" { + adminUsers = NormalizeUsernames(strings.Split(adminUsersRaw, ",")) + } + + return GitHubAuthConfig{ + ClientID: clientID, + ClientSecret: clientSecret, + ClientSecretFile: clientSecretFile, + RedirectURL: redirectURL, + AllowedUsers: allowedUsers, + AllowedOrgs: allowedOrgs, + AdminUsers: adminUsers, + } +} + +func ValidateGitHubAuthConfig(config GitHubAuthConfig) error { + if config.ClientID == "" { + return errors.New("github auth config: client_id is required") + } + + if config.ClientSecret == "" && config.ClientSecretFile == "" { + return errors.New("github auth config: client_secret is required") + } + + if config.ClientSecretFile != "" { + secret, err := os.ReadFile(config.ClientSecretFile) + if err != nil { + return fmt.Errorf("github auth config: reading client_secret_file %q: %w", config.ClientSecretFile, err) + } + if len(strings.TrimSpace(string(secret))) == 0 { + return fmt.Errorf("github auth config: client_secret_file %q is empty", config.ClientSecretFile) + } + } + + if config.RedirectURL == "" { + return errors.New("github auth config: redirect_url is required") + } + + if len(config.AllowedUsers) == 0 && len(config.AllowedOrgs) == 0 { + return errors.New("github auth config: at least one of allowed_users or allowed_orgs is required") + } + + if len(config.AdminUsers) == 0 { + return errors.New("github auth config: admin_users is required") + } + + return nil +} + +func NormalizeUsernames(users []string) []string { + seen := make(map[string]bool) + result := make([]string, 0, len(users)) + for _, user := range users { + normalized := strings.ToLower(strings.TrimSpace(user)) + if normalized != "" && !seen[normalized] { + seen[normalized] = true + result = append(result, normalized) + } + } + return result +} diff --git a/util/auth_config_test.go b/util/auth_config_test.go new file mode 100644 index 0000000..2468270 --- /dev/null +++ b/util/auth_config_test.go @@ -0,0 +1,309 @@ +package util + +import ( + "os" + "path/filepath" + "testing" +) + +func TestParseAuthMethod_DefaultLocal(t *testing.T) { + t.Setenv("WGUI_AUTH_METHOD", "") + method := ParseAuthMethod() + if method != AuthMethodLocal { + t.Errorf("expected default auth method to be %q, got %q", AuthMethodLocal, method) + } +} + +func TestParseAuthMethod_LocalExplicit(t *testing.T) { + t.Setenv("WGUI_AUTH_METHOD", "local") + method := ParseAuthMethod() + if method != AuthMethodLocal { + t.Errorf("expected auth method to be %q, got %q", AuthMethodLocal, method) + } +} + +func TestParseAuthMethod_GitHubExplicit(t *testing.T) { + t.Setenv("WGUI_AUTH_METHOD", "github") + method := ParseAuthMethod() + if method != AuthMethodGitHub { + t.Errorf("expected auth method to be %q, got %q", AuthMethodGitHub, method) + } +} + +func TestParseAuthMethod_GitHubCaseInsensitive(t *testing.T) { + t.Setenv("WGUI_AUTH_METHOD", "GitHub") + method := ParseAuthMethod() + if method != AuthMethodGitHub { + t.Errorf("expected case-insensitive auth method to be %q, got %q", AuthMethodGitHub, method) + } +} + +func TestParseGitHubAuthConfig_NormalizesLists(t *testing.T) { + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", " User1 ,user1, USER2 ") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", " OrgA ,orga ") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", " Admin1 , admin1 ") + + config := ParseGitHubAuthConfig() + + if config.ClientID != "test-client-id" { + t.Fatalf("expected client_id to be parsed, got %q", config.ClientID) + } + if config.ClientSecret != "test-secret" { + t.Fatalf("expected client_secret to be parsed, got %q", config.ClientSecret) + } + if config.RedirectURL != "http://localhost:8080/callback" { + t.Fatalf("expected redirect_url to be parsed, got %q", config.RedirectURL) + } + + expectedUsers := []string{"user1", "user2"} + if len(config.AllowedUsers) != len(expectedUsers) { + t.Fatalf("expected %d allowed users, got %d", len(expectedUsers), len(config.AllowedUsers)) + } + for i, user := range expectedUsers { + if config.AllowedUsers[i] != user { + t.Fatalf("expected allowed_users[%d]=%q, got %q", i, user, config.AllowedUsers[i]) + } + } + + expectedOrgs := []string{"orga"} + if len(config.AllowedOrgs) != len(expectedOrgs) { + t.Fatalf("expected %d allowed orgs, got %d", len(expectedOrgs), len(config.AllowedOrgs)) + } + for i, org := range expectedOrgs { + if config.AllowedOrgs[i] != org { + t.Fatalf("expected allowed_orgs[%d]=%q, got %q", i, org, config.AllowedOrgs[i]) + } + } + + expectedAdmins := []string{"admin1"} + if len(config.AdminUsers) != len(expectedAdmins) { + t.Fatalf("expected %d admin users, got %d", len(expectedAdmins), len(config.AdminUsers)) + } + for i, admin := range expectedAdmins { + if config.AdminUsers[i] != admin { + t.Fatalf("expected admin_users[%d]=%q, got %q", i, admin, config.AdminUsers[i]) + } + } +} + +func TestValidateGitHubAuthConfig_MissingClientID(t *testing.T) { + t.Setenv("WGUI_GITHUB_CLIENT_ID", "") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{} + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for missing client_id, got nil") + } +} + +func TestValidateGitHubAuthConfig_MissingClientSecret(t *testing.T) { + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + } + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for missing client_secret and client_secret_file, got nil") + } +} + +func TestValidateGitHubAuthConfig_MissingClientSecretFile(t *testing.T) { + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "secret") + os.WriteFile(secretFile, []byte("test-secret"), 0600) + + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", secretFile) + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + ClientSecretFile: secretFile, + } + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for missing client_secret and client_secret_file with no readable content, got nil") + } +} + +func TestValidateGitHubAuthConfig_ClientSecretFileReadError(t *testing.T) { + nonExistentFile := "/nonexistent/path/to/secret" + + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", nonExistentFile) + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + ClientSecretFile: nonExistentFile, + } + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for unreadable client_secret_file, got nil") + } +} + +func TestValidateGitHubAuthConfig_MissingRedirectURL(t *testing.T) { + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "secret") + os.WriteFile(secretFile, []byte("test-secret"), 0600) + + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET_FILE", "") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + ClientSecret: "test-secret", + } + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for missing redirect_url, got nil") + } +} + +func TestValidateGitHubAuthConfig_MissingAllowRules(t *testing.T) { + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "secret") + os.WriteFile(secretFile, []byte("test-secret"), 0600) + + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + } + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for missing allow rules (allowed_users or allowed_orgs), got nil") + } +} + +func TestValidateGitHubAuthConfig_MissingAdminUsers(t *testing.T) { + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "secret") + os.WriteFile(secretFile, []byte("test-secret"), 0600) + + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "user1,user2") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + AllowedUsers: []string{"user1", "user2"}, + AllowedOrgs: []string{}, + } + err := ValidateGitHubAuthConfig(config) + if err == nil { + t.Error("expected error for missing admin_users, got nil") + } +} + +func TestValidateGitHubAuthConfig_ValidMinimal(t *testing.T) { + tmpDir := t.TempDir() + secretFile := filepath.Join(tmpDir, "secret") + os.WriteFile(secretFile, []byte("test-secret"), 0600) + + t.Setenv("WGUI_GITHUB_CLIENT_ID", "test-client-id") + t.Setenv("WGUI_GITHUB_CLIENT_SECRET", "test-secret") + t.Setenv("WGUI_GITHUB_REDIRECT_URL", "http://localhost:8080/callback") + t.Setenv("WGUI_GITHUB_ALLOWED_USERS", "user1,user2") + t.Setenv("WGUI_GITHUB_ALLOWED_ORGS", "") + t.Setenv("WGUI_GITHUB_ADMIN_USERS", "admin1") + + config := GitHubAuthConfig{ + ClientID: "test-client-id", + ClientSecret: "test-secret", + RedirectURL: "http://localhost:8080/callback", + AllowedUsers: []string{"user1", "user2"}, + AdminUsers: []string{"admin1"}, + } + err := ValidateGitHubAuthConfig(config) + if err != nil { + t.Errorf("expected no error for valid config, got %v", err) + } +} + +func TestNormalizeUsernames_LowercaseAndTrim(t *testing.T) { + input := []string{" User1 ", "USER2", " user3 "} + expected := []string{"user1", "user2", "user3"} + result := NormalizeUsernames(input) + if len(result) != len(expected) { + t.Fatalf("expected length %d, got %d", len(expected), len(result)) + } + for i, v := range result { + if v != expected[i] { + t.Errorf("expected[%d]=%q, got[%d]=%q", i, expected[i], i, v) + } + } +} + +func TestNormalizeUsernames_Deduplicate(t *testing.T) { + input := []string{"User1", "user1", "USER1", "User2"} + expected := []string{"user1", "user2"} + result := NormalizeUsernames(input) + if len(result) != len(expected) { + t.Fatalf("expected length %d, got %d", len(expected), len(result)) + } + for i, v := range result { + if v != expected[i] { + t.Errorf("expected[%d]=%q, got[%d]=%q", i, expected[i], i, v) + } + } +} + +func TestNormalizeUsernames_Empty(t *testing.T) { + input := []string{} + expected := []string{} + result := NormalizeUsernames(input) + if len(result) != len(expected) { + t.Fatalf("expected length %d, got %d", len(expected), len(result)) + } +} + +func TestNormalizeUsernames_BlankEntriesDropped(t *testing.T) { + result := NormalizeUsernames([]string{"", " ", " user1 "}) + if len(result) != 1 || result[0] != "user1" { + t.Fatalf("expected blank entries to be dropped, got %#v", result) + } +} diff --git a/util/cache.go b/util/cache.go index c247117..fcb5e28 100644 --- a/util/cache.go +++ b/util/cache.go @@ -1,6 +1,4 @@ package util -import "sync" - var IPToSubnetRange = map[string]uint16{} var DBUsersToCRC32 = map[string]uint32{} diff --git a/util/config.go b/util/config.go index 4af6bd2..469f562 100644 --- a/util/config.go +++ b/util/config.go @@ -9,25 +9,27 @@ import ( // Runtime config var ( - DisableLogin bool - BindAddress string - SmtpHostname string - SmtpPort int - SmtpUsername string - SmtpPassword string - SmtpNoTLSCheck bool - SmtpEncryption string - SmtpAuthType string - SmtpHelo string - SendgridApiKey string - EmailFrom string - EmailFromName string - SessionSecret [64]byte - SessionMaxDuration int64 - WgConfTemplate string - BasePath string - SubnetRanges map[string]([]*net.IPNet) - SubnetRangesOrder []string + DisableLogin bool + BindAddress string + SmtpHostname string + SmtpPort int + SmtpUsername string + SmtpPassword string + SmtpNoTLSCheck bool + SmtpEncryption string + SmtpAuthType string + SmtpHelo string + SendgridApiKey string + EmailFrom string + EmailFromName string + SessionSecret [64]byte + SessionMaxDuration int64 + WgConfTemplate string + BasePath string + SubnetRanges map[string]([]*net.IPNet) + SubnetRangesOrder []string + CurrentAuthMethod AuthMethod + CurrentGitHubConfig GitHubAuthConfig ) const ( diff --git a/util/util.go b/util/util.go index 3b8b7be..5c2fac6 100644 --- a/util/util.go +++ b/util/util.go @@ -2,8 +2,7 @@ package util import ( "bufio" - "bytes" - "encoding/gob" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -727,13 +726,28 @@ func ManagePerms(path string) error { return err } +// GetDBUserCRC32 builds a stable session hash from security-relevant user fields. +// DisplayName is intentionally excluded so profile presentation changes do not invalidate sessions. func GetDBUserCRC32(dbuser model.User) uint32 { - buf := new(bytes.Buffer) - enc := gob.NewEncoder(buf) - if err := enc.Encode(dbuser); err != nil { - panic("model.User is gob-incompatible, session verification is impossible") + h := crc32.NewIEEE() + writeHashField(h, dbuser.Username) + if dbuser.Admin { + writeHashField(h, "1") + } else { + writeHashField(h, "0") } - return crc32.ChecksumIEEE(buf.Bytes()) + writeHashField(h, dbuser.Password) + writeHashField(h, dbuser.PasswordHash) + writeHashField(h, dbuser.AuthSource) + writeHashField(h, dbuser.AuthSubject) + return h.Sum32() +} + +func writeHashField(h io.Writer, value string) { + var length [4]byte + binary.BigEndian.PutUint32(length[:], uint32(len(value))) + _, _ = h.Write(length[:]) + _, _ = h.Write([]byte(value)) } func ConcatMultipleSlices(slices ...[]byte) []byte { diff --git a/util/util_test.go b/util/util_test.go new file mode 100644 index 0000000..e487b8a --- /dev/null +++ b/util/util_test.go @@ -0,0 +1,176 @@ +package util + +import ( + "testing" + + "github.com/ngoduykhanh/wireguard-ui/model" +) + +func TestGetDBUserCRC32_IgnoresDisplayName(t *testing.T) { + user1 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Test User", + } + + user2 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Different Name", + } + + hash1 := GetDBUserCRC32(user1) + hash2 := GetDBUserCRC32(user2) + + if hash1 != hash2 { + t.Errorf("CRC32 should be equal when only DisplayName differs: hash1=%d, hash2=%d", hash1, hash2) + } +} + +func TestGetDBUserCRC32_ChangesWhenAdminChanges(t *testing.T) { + user1 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Test User", + } + + user2 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: false, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Test User", + } + + hash1 := GetDBUserCRC32(user1) + hash2 := GetDBUserCRC32(user2) + + if hash1 == hash2 { + t.Errorf("CRC32 should change when Admin field changes: hash1=%d, hash2=%d", hash1, hash2) + } +} + +func TestGetDBUserCRC32_ChangesWhenPasswordChanges(t *testing.T) { + user1 := model.User{ + Username: "testuser", + Password: "password1", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Test User", + } + + user2 := model.User{ + Username: "testuser", + Password: "password2", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Test User", + } + + hash1 := GetDBUserCRC32(user1) + hash2 := GetDBUserCRC32(user2) + + if hash1 == hash2 { + t.Errorf("CRC32 should change when Password field changes: hash1=%d, hash2=%d", hash1, hash2) + } +} + +func TestGetDBUserCRC32_ChangesWhenAuthSourceChanges(t *testing.T) { + user1 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "local", + AuthSubject: "", + DisplayName: "Test User", + } + + user2 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "github", + AuthSubject: "12345", + DisplayName: "Test User", + } + + hash1 := GetDBUserCRC32(user1) + hash2 := GetDBUserCRC32(user2) + + if hash1 == hash2 { + t.Errorf("CRC32 should change when AuthSource field changes: hash1=%d, hash2=%d", hash1, hash2) + } +} + +func TestGetDBUserCRC32_ChangesWhenAuthSubjectChanges(t *testing.T) { + user1 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "github", + AuthSubject: "12345", + DisplayName: "Test User", + } + + user2 := model.User{ + Username: "testuser", + Password: "password", + PasswordHash: "hash", + Admin: true, + AuthSource: "github", + AuthSubject: "67890", + DisplayName: "Test User", + } + + hash1 := GetDBUserCRC32(user1) + hash2 := GetDBUserCRC32(user2) + + if hash1 == hash2 { + t.Errorf("CRC32 should change when AuthSubject field changes: hash1=%d, hash2=%d", hash1, hash2) + } +} + +func TestGetDBUserCRC32_DistinguishesEmbeddedNulls(t *testing.T) { + user1 := model.User{ + Username: "user\x00name", + Password: "pass", + PasswordHash: "hash", + Admin: false, + AuthSource: "github", + AuthSubject: "subject", + } + + user2 := model.User{ + Username: "user", + Password: "\x00namepass", + PasswordHash: "hash", + Admin: false, + AuthSource: "github", + AuthSubject: "subject", + } + + if GetDBUserCRC32(user1) == GetDBUserCRC32(user2) { + t.Fatal("CRC32 should distinguish values even when fields contain embedded null bytes") + } +}