From e052f400aad22e0f7baf4645e9af9231fad82660 Mon Sep 17 00:00:00 2001 From: Christoph Haas Date: Thu, 22 Apr 2021 20:29:37 +0200 Subject: [PATCH] convert all email addresses to lower case (#14) --- .../providers/password/provider.go | 1 + internal/common/db.go | 50 +++++++++++++++---- internal/server/version.go | 2 +- internal/users/manager.go | 11 ++++ internal/wireguard/peermanager.go | 4 ++ 5 files changed, 58 insertions(+), 10 deletions(-) diff --git a/internal/authentication/providers/password/provider.go b/internal/authentication/providers/password/provider.go index 9767d32..53ad77b 100644 --- a/internal/authentication/providers/password/provider.go +++ b/internal/authentication/providers/password/provider.go @@ -108,6 +108,7 @@ func (provider Provider) GetUserModel(ctx *authentication.AuthContext) (*authent } func (provider Provider) InitializeAdmin(email, password string) error { + email = strings.ToLower(email) if !emailRegex.MatchString(email) { return errors.New("admin username must be an email address") } diff --git a/internal/common/db.go b/internal/common/db.go index 7c8d5a1..63e7f7c 100644 --- a/internal/common/db.go +++ b/internal/common/db.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "sort" "time" "github.com/pkg/errors" @@ -14,6 +15,22 @@ import ( "gorm.io/gorm/logger" ) +func init() { + migrations = append(migrations, Migration{ + version: "1.0.7", + migrateFn: func(db *gorm.DB) error { + if err := db.Exec("UPDATE users SET email = LOWER(email)").Error; err != nil { + return errors.Wrap(err, "failed to convert user emails to lower case") + } + if err := db.Exec("UPDATE peers SET email = LOWER(email)").Error; err != nil { + return errors.Wrap(err, "failed to convert peer emails to lower case") + } + logrus.Infof("upgraded database format to version 1.0.7") + return nil + }, + }) +} + type SupportedDatabase string const ( @@ -80,16 +97,18 @@ type DatabaseMigrationInfo struct { Applied time.Time } +type Migration struct { + version string + migrateFn func(db *gorm.DB) error +} + +var migrations []Migration + func MigrateDatabase(db *gorm.DB, version string) error { if err := db.AutoMigrate(&DatabaseMigrationInfo{}); err != nil { return errors.Wrap(err, "failed to migrate version database") } - newVersion := DatabaseMigrationInfo{ - Version: version, - Applied: time.Now(), - } - existingMigration := DatabaseMigrationInfo{} db.Where("version = ?", version).FirstOrInit(&existingMigration) @@ -97,11 +116,24 @@ func MigrateDatabase(db *gorm.DB, version string) error { lastVersion := DatabaseMigrationInfo{} db.Order("applied desc, version desc").FirstOrInit(&lastVersion) - // TODO: migrate database + sort.Slice(migrations, func(i, j int) bool { + return migrations[i].version < migrations[j].version + }) - res := db.Create(&newVersion) - if res.Error != nil { - return errors.Wrap(res.Error, "failed to write version to database") + for _, migration := range migrations { + if migration.version > lastVersion.Version { + if err := migration.migrateFn(db); err != nil { + return errors.Wrapf(err, "failed to migrate to version %s", migration.version) + } + + res := db.Create(&DatabaseMigrationInfo{ + Version: migration.version, + Applied: time.Now(), + }) + if res.Error != nil { + return errors.Wrapf(res.Error, "failed to write version %s to database", migration.version) + } + } } } diff --git a/internal/server/version.go b/internal/server/version.go index 718153c..e91e867 100644 --- a/internal/server/version.go +++ b/internal/server/version.go @@ -1,4 +1,4 @@ package server var Version = "testbuild" -var DatabaseVersion = "1.0.6" +var DatabaseVersion = "1.0.7" diff --git a/internal/users/manager.go b/internal/users/manager.go index 8cd9fe7..1287c72 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -51,6 +51,8 @@ func (m Manager) UserExists(email string) bool { } func (m Manager) GetUser(email string) *User { + email = strings.ToLower(email) + user := User{} m.db.Where("email = ?", email).First(&user) @@ -62,6 +64,8 @@ func (m Manager) GetUser(email string) *User { } func (m Manager) GetUserUnscoped(email string) *User { + email = strings.ToLower(email) + user := User{} m.db.Unscoped().Where("email = ?", email).First(&user) @@ -93,6 +97,8 @@ func (m Manager) GetFilteredAndSortedUsersUnscoped(sortKey, sortDirection, searc } func (m Manager) GetOrCreateUser(email string) (*User, error) { + email = strings.ToLower(email) + user := User{} m.db.Where("email = ?", email).FirstOrInit(&user) @@ -113,6 +119,8 @@ func (m Manager) GetOrCreateUser(email string) (*User, error) { } func (m Manager) GetOrCreateUserUnscoped(email string) (*User, error) { + email = strings.ToLower(email) + user := User{} m.db.Unscoped().Where("email = ?", email).FirstOrInit(&user) @@ -133,6 +141,7 @@ func (m Manager) GetOrCreateUserUnscoped(email string) (*User, error) { } func (m Manager) CreateUser(user *User) error { + user.Email = strings.ToLower(user.Email) res := m.db.Create(user) if res.Error != nil { return errors.Wrapf(res.Error, "failed to create user %s", user.Email) @@ -142,6 +151,7 @@ func (m Manager) CreateUser(user *User) error { } func (m Manager) UpdateUser(user *User) error { + user.Email = strings.ToLower(user.Email) res := m.db.Save(user) if res.Error != nil { return errors.Wrapf(res.Error, "failed to update user %s", user.Email) @@ -151,6 +161,7 @@ func (m Manager) UpdateUser(user *User) error { } func (m Manager) DeleteUser(user *User) error { + user.Email = strings.ToLower(user.Email) res := m.db.Delete(user) if res.Error != nil { return errors.Wrapf(res.Error, "failed to update user %s", user.Email) diff --git a/internal/wireguard/peermanager.go b/internal/wireguard/peermanager.go index efc741a..3299bea 100644 --- a/internal/wireguard/peermanager.go +++ b/internal/wireguard/peermanager.go @@ -623,6 +623,7 @@ func (m *PeerManager) GetFilteredAndSortedPeers(device, sortKey, sortDirection, } func (m *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer { + email = strings.ToLower(email) peers := make([]Peer, 0) m.db.Where("email = ?", email).Find(&peers) @@ -691,6 +692,7 @@ func (m *PeerManager) GetPeerByKey(publicKey string) Peer { } func (m *PeerManager) GetPeersByMail(mail string) []Peer { + mail = strings.ToLower(mail) var peers []Peer m.db.Where("email = ?", mail).Find(&peers) for i := range peers { @@ -706,6 +708,7 @@ func (m *PeerManager) CreatePeer(peer Peer) error { peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) peer.UpdatedAt = time.Now() peer.CreatedAt = time.Now() + peer.Email = strings.ToLower(peer.Email) res := m.db.Create(&peer) if res.Error != nil { @@ -718,6 +721,7 @@ func (m *PeerManager) CreatePeer(peer Peer) error { func (m *PeerManager) UpdatePeer(peer Peer) error { peer.UpdatedAt = time.Now() + peer.Email = strings.ToLower(peer.Email) res := m.db.Save(&peer) if res.Error != nil {