chore: code adjustment
This commit is contained in:
		
							parent
							
								
									8cfe9a3d5b
								
							
						
					
					
						commit
						a12ed9bf51
					
				|  | @ -75,7 +75,8 @@ func Login(db store.IStore) echo.HandlerFunc { | ||||||
| 
 | 
 | ||||||
| 		dbuser, err := db.GetUserByName(username) | 		dbuser, err := db.GetUserByName(username) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Cannot query user from DB"}) | 			log.Infof("Cannot query user %s from DB", username) | ||||||
|  | 			return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "Invalid credentials"}) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		userCorrect := subtle.ConstantTimeCompare([]byte(username), []byte(dbuser.Username)) == 1 | 		userCorrect := subtle.ConstantTimeCompare([]byte(username), []byte(dbuser.Username)) == 1 | ||||||
|  | @ -173,7 +174,7 @@ func Logout() echo.HandlerFunc { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // LoadProfile to load user information
 | // LoadProfile to load user information
 | ||||||
| func LoadProfile(db store.IStore) echo.HandlerFunc { | func LoadProfile() echo.HandlerFunc { | ||||||
| 	return func(c echo.Context) error { | 	return func(c echo.Context) error { | ||||||
| 		return c.Render(http.StatusOK, "profile.html", map[string]interface{}{ | 		return c.Render(http.StatusOK, "profile.html", map[string]interface{}{ | ||||||
| 			"baseData": model.BaseData{Active: "profile", CurrentUser: currentUser(c), Admin: isAdmin(c)}, | 			"baseData": model.BaseData{Active: "profile", CurrentUser: currentUser(c), Admin: isAdmin(c)}, | ||||||
|  | @ -182,7 +183,7 @@ func LoadProfile(db store.IStore) echo.HandlerFunc { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // UsersSettings handler
 | // UsersSettings handler
 | ||||||
| func UsersSettings(db store.IStore) echo.HandlerFunc { | func UsersSettings() echo.HandlerFunc { | ||||||
| 	return func(c echo.Context) error { | 	return func(c echo.Context) error { | ||||||
| 		return c.Render(http.StatusOK, "users_settings.html", map[string]interface{}{ | 		return c.Render(http.StatusOK, "users_settings.html", map[string]interface{}{ | ||||||
| 			"baseData": model.BaseData{Active: "users-settings", CurrentUser: currentUser(c), Admin: isAdmin(c)}, | 			"baseData": model.BaseData{Active: "users-settings", CurrentUser: currentUser(c), Admin: isAdmin(c)}, | ||||||
|  |  | ||||||
							
								
								
									
										55
									
								
								main.go
								
								
								
								
							
							
						
						
									
										55
									
								
								main.go
								
								
								
								
							|  | @ -31,23 +31,23 @@ var ( | ||||||
| 	gitRef     = "N/A" | 	gitRef     = "N/A" | ||||||
| 	buildTime  = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05")) | 	buildTime  = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05")) | ||||||
| 	// configuration variables
 | 	// configuration variables
 | ||||||
| 	flagDisableLogin             bool   = false | 	flagDisableLogin             = false | ||||||
| 	flagBindAddress              string = "0.0.0.0:5000" | 	flagBindAddress              = "0.0.0.0:5000" | ||||||
| 	flagSmtpHostname             string = "127.0.0.1" | 	flagSmtpHostname             = "127.0.0.1" | ||||||
| 	flagSmtpPort                 int    = 25 | 	flagSmtpPort                 = 25 | ||||||
| 	flagSmtpUsername             string | 	flagSmtpUsername             string | ||||||
| 	flagSmtpPassword             string | 	flagSmtpPassword             string | ||||||
| 	flagSmtpAuthType             string = "NONE" | 	flagSmtpAuthType             = "NONE" | ||||||
| 	flagSmtpNoTLSCheck           bool   = false | 	flagSmtpNoTLSCheck           = false | ||||||
| 	flagSmtpEncryption           string = "STARTTLS" | 	flagSmtpEncryption           = "STARTTLS" | ||||||
| 	flagSmtpHelo                 string = "localhost" | 	flagSmtpHelo                 = "localhost" | ||||||
| 	flagSendgridApiKey           string | 	flagSendgridApiKey           string | ||||||
| 	flagEmailFrom                string | 	flagEmailFrom                string | ||||||
| 	flagEmailFromName            string = "WireGuard UI" | 	flagEmailFromName            = "WireGuard UI" | ||||||
| 	flagTelegramToken            string | 	flagTelegramToken            string | ||||||
| 	flagTelegramAllowConfRequest bool   = false | 	flagTelegramAllowConfRequest = false | ||||||
| 	flagTelegramFloodWait        int    = 60 | 	flagTelegramFloodWait        = 60 | ||||||
| 	flagSessionSecret            string = util.RandomString(32) | 	flagSessionSecret            = util.RandomString(32) | ||||||
| 	flagWgConfTemplate           string | 	flagWgConfTemplate           string | ||||||
| 	flagBasePath                 string | 	flagBasePath                 string | ||||||
| 	flagSubnetRanges             string | 	flagSubnetRanges             string | ||||||
|  | @ -95,7 +95,7 @@ func init() { | ||||||
| 
 | 
 | ||||||
| 	var ( | 	var ( | ||||||
| 		smtpPasswordLookup   = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword) | 		smtpPasswordLookup   = util.LookupEnvOrString("SMTP_PASSWORD", flagSmtpPassword) | ||||||
| 		sengridApiKeyLookup = util.LookupEnvOrString("SENDGRID_API_KEY", flagSendgridApiKey) | 		sendgridApiKeyLookup = util.LookupEnvOrString("SENDGRID_API_KEY", flagSendgridApiKey) | ||||||
| 		sessionSecretLookup  = util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret) | 		sessionSecretLookup  = util.LookupEnvOrString("SESSION_SECRET", flagSessionSecret) | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
|  | @ -106,9 +106,9 @@ func init() { | ||||||
| 		flag.StringVar(&flagSmtpPassword, "smtp-password", util.LookupEnvOrFile("SMTP_PASSWORD_FILE", flagSmtpPassword), "SMTP Password File") | 		flag.StringVar(&flagSmtpPassword, "smtp-password", util.LookupEnvOrFile("SMTP_PASSWORD_FILE", flagSmtpPassword), "SMTP Password File") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// check empty sengridApiKey env var
 | 	// check empty sendgridApiKey env var
 | ||||||
| 	if sengridApiKeyLookup != "" { | 	if sendgridApiKeyLookup != "" { | ||||||
| 		flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", sengridApiKeyLookup, "Your sendgrid api key.") | 		flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", sendgridApiKeyLookup, "Your sendgrid api key.") | ||||||
| 	} else { | 	} else { | ||||||
| 		flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", util.LookupEnvOrFile("SENDGRID_API_KEY_FILE", flagSendgridApiKey), "File containing your sendgrid api key.") | 		flag.StringVar(&flagSendgridApiKey, "sendgrid-api-key", util.LookupEnvOrFile("SENDGRID_API_KEY_FILE", flagSendgridApiKey), "File containing your sendgrid api key.") | ||||||
| 	} | 	} | ||||||
|  | @ -215,12 +215,12 @@ func main() { | ||||||
| 		app.GET(util.BasePath+"/login", handler.LoginPage()) | 		app.GET(util.BasePath+"/login", handler.LoginPage()) | ||||||
| 		app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson) | 		app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson) | ||||||
| 		app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession) | 		app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession) | ||||||
| 		app.GET(util.BasePath+"/profile", handler.LoadProfile(db), handler.ValidSession) | 		app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession) | ||||||
| 		app.GET(util.BasePath+"/users-settings", handler.UsersSettings(db), handler.ValidSession, handler.NeedsAdmin) | 		app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.NeedsAdmin) | ||||||
| 		app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson) | 		app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson) | ||||||
| 		app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) | 		app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) | ||||||
| 		app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) | 		app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin) | ||||||
| 		app.GET(util.BasePath+"/getusers", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin) | 		app.GET(util.BasePath+"/get-users", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin) | ||||||
| 		app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession) | 		app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -276,10 +276,13 @@ func main() { | ||||||
| 	if strings.HasPrefix(util.BindAddress, "unix://") { | 	if strings.HasPrefix(util.BindAddress, "unix://") { | ||||||
| 		// Listen on unix domain socket.
 | 		// Listen on unix domain socket.
 | ||||||
| 		// https://github.com/labstack/echo/issues/830
 | 		// https://github.com/labstack/echo/issues/830
 | ||||||
| 		syscall.Unlink(util.BindAddress[6:]) | 		err := syscall.Unlink(util.BindAddress[6:]) | ||||||
|  | 		if err != nil { | ||||||
|  | 			app.Logger.Fatalf("Cannot unlink unix socket: Error: %v", err) | ||||||
|  | 		} | ||||||
| 		l, err := net.Listen("unix", util.BindAddress[6:]) | 		l, err := net.Listen("unix", util.BindAddress[6:]) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			app.Logger.Fatal(err) | 			app.Logger.Fatalf("Cannot create unix socket. Error: %v", err) | ||||||
| 		} | 		} | ||||||
| 		app.Listener = l | 		app.Listener = l | ||||||
| 		app.Logger.Fatal(app.Start("")) | 		app.Logger.Fatal(app.Start("")) | ||||||
|  | @ -292,7 +295,7 @@ func main() { | ||||||
| func initServerConfig(db store.IStore, tmplDir fs.FS) { | func initServerConfig(db store.IStore, tmplDir fs.FS) { | ||||||
| 	settings, err := db.GetGlobalSettings() | 	settings, err := db.GetGlobalSettings() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Cannot get global settings: ", err) | 		log.Fatalf("Cannot get global settings: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if _, err := os.Stat(settings.ConfigFilePath); err == nil { | 	if _, err := os.Stat(settings.ConfigFilePath); err == nil { | ||||||
|  | @ -302,23 +305,23 @@ func initServerConfig(db store.IStore, tmplDir fs.FS) { | ||||||
| 
 | 
 | ||||||
| 	server, err := db.GetServer() | 	server, err := db.GetServer() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Cannot get server config: ", err) | 		log.Fatalf("Cannot get server config: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	clients, err := db.GetClients(false) | 	clients, err := db.GetClients(false) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Cannot get client config: ", err) | 		log.Fatalf("Cannot get client config: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	users, err := db.GetUsers() | 	users, err := db.GetUsers() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Cannot get user config: ", err) | 		log.Fatalf("Cannot get user config: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// write config file
 | 	// write config file
 | ||||||
| 	err = util.WriteWireGuardServerConfig(tmplDir, server, clients, users, settings) | 	err = util.WriteWireGuardServerConfig(tmplDir, server, clients, users, settings) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("Cannot create server config: ", err) | 		log.Fatalf("Cannot create server config: %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -37,14 +37,14 @@ func New(dbPath string) (*JsonDB, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (o *JsonDB) Init() error { | func (o *JsonDB) Init() error { | ||||||
| 	var clientPath string = path.Join(o.dbPath, "clients") | 	var clientPath = path.Join(o.dbPath, "clients") | ||||||
| 	var serverPath string = path.Join(o.dbPath, "server") | 	var serverPath = path.Join(o.dbPath, "server") | ||||||
| 	var userPath string = path.Join(o.dbPath, "users") | 	var userPath = path.Join(o.dbPath, "users") | ||||||
| 	var wakeOnLanHostsPath string = path.Join(o.dbPath, "wake_on_lan_hosts") | 	var wakeOnLanHostsPath = path.Join(o.dbPath, "wake_on_lan_hosts") | ||||||
| 	var serverInterfacePath string = path.Join(serverPath, "interfaces.json") | 	var serverInterfacePath = path.Join(serverPath, "interfaces.json") | ||||||
| 	var serverKeyPairPath string = path.Join(serverPath, "keypair.json") | 	var serverKeyPairPath = path.Join(serverPath, "keypair.json") | ||||||
| 	var globalSettingPath string = path.Join(serverPath, "global_settings.json") | 	var globalSettingPath = path.Join(serverPath, "global_settings.json") | ||||||
| 	var hashesPath string = path.Join(serverPath, "hashes.json") | 	var hashesPath = path.Join(serverPath, "hashes.json") | ||||||
| 
 | 
 | ||||||
| 	// create directories if they do not exist
 | 	// create directories if they do not exist
 | ||||||
| 	if _, err := os.Stat(clientPath); os.IsNotExist(err) { | 	if _, err := os.Stat(clientPath); os.IsNotExist(err) { | ||||||
|  | @ -189,7 +189,7 @@ func (o *JsonDB) GetUsers() ([]model.User, error) { | ||||||
| 	for _, i := range results { | 	for _, i := range results { | ||||||
| 		user := model.User{} | 		user := model.User{} | ||||||
| 
 | 
 | ||||||
| 		if err := json.Unmarshal([]byte(i), &user); err != nil { | 		if err := json.Unmarshal(i, &user); err != nil { | ||||||
| 			return users, fmt.Errorf("cannot decode user json structure: %v", err) | 			return users, fmt.Errorf("cannot decode user json structure: %v", err) | ||||||
| 		} | 		} | ||||||
| 		users = append(users, user) | 		users = append(users, user) | ||||||
|  | @ -267,7 +267,7 @@ func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) { | ||||||
| 		clientData := model.ClientData{} | 		clientData := model.ClientData{} | ||||||
| 
 | 
 | ||||||
| 		// get client info
 | 		// get client info
 | ||||||
| 		if err := json.Unmarshal([]byte(f), &client); err != nil { | 		if err := json.Unmarshal(f, &client); err != nil { | ||||||
| 			return clients, fmt.Errorf("cannot decode client json structure: %v", err) | 			return clients, fmt.Errorf("cannot decode client json structure: %v", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -278,7 +278,7 @@ func (o *JsonDB) GetClients(hasQRCode bool) ([]model.ClientData, error) { | ||||||
| 
 | 
 | ||||||
| 			png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) | 			png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) | ||||||
| 			if err == nil { | 			if err == nil { | ||||||
| 				clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) | 				clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png) | ||||||
| 			} else { | 			} else { | ||||||
| 				fmt.Print("Cannot generate QR code: ", err) | 				fmt.Print("Cannot generate QR code: ", err) | ||||||
| 			} | 			} | ||||||
|  | @ -315,7 +315,7 @@ func (o *JsonDB) GetClientByID(clientID string, qrCodeSettings model.QRCodeSetti | ||||||
| 
 | 
 | ||||||
| 		png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) | 		png, err := qrcode.Encode(util.BuildClientConfig(client, server, globalSettings), qrcode.Medium, 256) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString([]byte(png)) | 			clientData.QRCode = "data:image/png;base64," + base64.StdEncoding.EncodeToString(png) | ||||||
| 		} else { | 		} else { | ||||||
| 			fmt.Print("Cannot generate QR code: ", err) | 			fmt.Print("Cannot generate QR code: ", err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | @ -23,7 +23,7 @@ func (o *JsonDB) GetWakeOnLanHosts() ([]model.WakeOnLanHost, error) { | ||||||
| 		host := model.WakeOnLanHost{} | 		host := model.WakeOnLanHost{} | ||||||
| 
 | 
 | ||||||
| 		// get client info
 | 		// get client info
 | ||||||
| 		if err := json.Unmarshal([]byte(f), &host); err != nil { | 		if err := json.Unmarshal(f, &host); err != nil { | ||||||
| 			return hosts, fmt.Errorf("cannot decode client json structure: %v", err) | 			return hosts, fmt.Errorf("cannot decode client json structure: %v", err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -26,8 +26,8 @@ var ( | ||||||
| 	Bot      *echotron.API | 	Bot      *echotron.API | ||||||
| 	BotMutex sync.RWMutex | 	BotMutex sync.RWMutex | ||||||
| 
 | 
 | ||||||
| 	floodWait        = make(map[int64]int64, 0) | 	floodWait        = make(map[int64]int64) | ||||||
| 	floodMessageSent = make(map[int64]struct{}, 0) | 	floodMessageSent = make(map[int64]struct{}) | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func Start(initDeps TgBotInitDependencies) (err error) { | func Start(initDeps TgBotInitDependencies) (err error) { | ||||||
|  | @ -84,12 +84,15 @@ func Start(initDeps TgBotInitDependencies) (err error) { | ||||||
| 					continue | 					continue | ||||||
| 				} | 				} | ||||||
| 				floodMessageSent[userid] = struct{}{} | 				floodMessageSent[userid] = struct{}{} | ||||||
| 				bot.SendMessage( | 				_, err := bot.SendMessage( | ||||||
| 					fmt.Sprintf("You can only request your configs once per %d minutes", FloodWait), | 					fmt.Sprintf("You can only request your configs once per %d minutes", FloodWait), | ||||||
| 					userid, | 					userid, | ||||||
| 					&echotron.MessageOptions{ | 					&echotron.MessageOptions{ | ||||||
| 						ReplyToMessageID: update.Message.ID, | 						ReplyToMessageID: update.Message.ID, | ||||||
| 					}) | 					}) | ||||||
|  | 				if err != nil { | ||||||
|  | 					log.Errorf("Failed to send telegram message. Error %v", err) | ||||||
|  | 				} | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			floodWait[userid] = time.Now().Unix() | 			floodWait[userid] = time.Now().Unix() | ||||||
|  | @ -100,12 +103,15 @@ func Start(initDeps TgBotInitDependencies) (err error) { | ||||||
| 				for _, f := range failed { | 				for _, f := range failed { | ||||||
| 					messageText += f + "\n" | 					messageText += f + "\n" | ||||||
| 				} | 				} | ||||||
| 				bot.SendMessage( | 				_, err := bot.SendMessage( | ||||||
| 					messageText, | 					messageText, | ||||||
| 					userid, | 					userid, | ||||||
| 					&echotron.MessageOptions{ | 					&echotron.MessageOptions{ | ||||||
| 						ReplyToMessageID: update.Message.ID, | 						ReplyToMessageID: update.Message.ID, | ||||||
| 					}) | 					}) | ||||||
|  | 				if err != nil { | ||||||
|  | 					log.Errorf("Failed to send telegram message. Error %v", err) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -96,7 +96,7 @@ Users Settings | ||||||
|         $.ajax({ |         $.ajax({ | ||||||
|             cache: false, |             cache: false, | ||||||
|             method: 'GET', |             method: 'GET', | ||||||
|             url: '{{.basePath}}/getusers', |             url: '{{.basePath}}/get-users', | ||||||
|             dataType: 'json', |             dataType: 'json', | ||||||
|             contentType: "application/json", |             contentType: "application/json", | ||||||
|             success: function (data) { |             success: function (data) { | ||||||
|  |  | ||||||
|  | @ -3,5 +3,5 @@ package util | ||||||
| import "sync" | import "sync" | ||||||
| 
 | 
 | ||||||
| var IPToSubnetRange = map[string]uint16{} | var IPToSubnetRange = map[string]uint16{} | ||||||
| var TgUseridToClientID = map[int64]([]string){} | var TgUseridToClientID = map[int64][]string{} | ||||||
| var TgUseridToClientIDMutex sync.RWMutex | var TgUseridToClientIDMutex sync.RWMutex | ||||||
|  |  | ||||||
|  | @ -2,6 +2,7 @@ package util | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
|  | @ -20,7 +21,7 @@ func VerifyHash(base64Hash string, plaintext string) (bool, error) { | ||||||
| 		return false, fmt.Errorf("cannot decode base64 hash: %w", err) | 		return false, fmt.Errorf("cannot decode base64 hash: %w", err) | ||||||
| 	} | 	} | ||||||
| 	err = bcrypt.CompareHashAndPassword(hash, []byte(plaintext)) | 	err = bcrypt.CompareHashAndPassword(hash, []byte(plaintext)) | ||||||
| 	if err == bcrypt.ErrMismatchedHashAndPassword { | 	if errors.Is(err, bcrypt.ErrMismatchedHashAndPassword) { | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
							
								
								
									
										35
									
								
								util/util.go
								
								
								
								
							
							
						
						
									
										35
									
								
								util/util.go
								
								
								
								
							|  | @ -7,7 +7,6 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/fs" | 	"io/fs" | ||||||
| 	"io/ioutil" |  | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"net" | 	"net" | ||||||
| 	"os" | 	"os" | ||||||
|  | @ -189,7 +188,7 @@ func GetInterfaceIPs() ([]model.Interface, error) { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	var interfaceList = []model.Interface{} | 	var interfaceList []model.Interface | ||||||
| 
 | 
 | ||||||
| 	// get interface's ip addresses
 | 	// get interface's ip addresses
 | ||||||
| 	for _, i := range ifaces { | 	for _, i := range ifaces { | ||||||
|  | @ -230,9 +229,9 @@ func GetPublicIP() (model.Interface, error) { | ||||||
| 	consensus := externalip.NewConsensus(&cfg, nil) | 	consensus := externalip.NewConsensus(&cfg, nil) | ||||||
| 
 | 
 | ||||||
| 	// add trusted voters
 | 	// add trusted voters
 | ||||||
| 	consensus.AddVoter(externalip.NewHTTPSource("http://checkip.amazonaws.com/"), 1) | 	consensus.AddVoter(externalip.NewHTTPSource("https://checkip.amazonaws.com/"), 1) | ||||||
| 	consensus.AddVoter(externalip.NewHTTPSource("http://whatismyip.akamai.com"), 1) | 	consensus.AddVoter(externalip.NewHTTPSource("http://whatismyip.akamai.com"), 1) | ||||||
| 	consensus.AddVoter(externalip.NewHTTPSource("http://ifconfig.top"), 1) | 	consensus.AddVoter(externalip.NewHTTPSource("https://ifconfig.top"), 1) | ||||||
| 
 | 
 | ||||||
| 	publicInterface := model.Interface{} | 	publicInterface := model.Interface{} | ||||||
| 	publicInterface.Name = "Public Address" | 	publicInterface.Name = "Public Address" | ||||||
|  | @ -244,7 +243,7 @@ func GetPublicIP() (model.Interface, error) { | ||||||
| 		publicInterface.IPAddress = ip.String() | 		publicInterface.IPAddress = ip.String() | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// error handling happend above, no need to pass it through
 | 	// error handling happened above, no need to pass it through
 | ||||||
| 	return publicInterface, nil | 	return publicInterface, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -292,7 +291,7 @@ func GetAllocatedIPs(ignoreClientID string) ([]string, error) { | ||||||
| 	// append client's addresses to the result
 | 	// append client's addresses to the result
 | ||||||
| 	for _, f := range records { | 	for _, f := range records { | ||||||
| 		client := model.Client{} | 		client := model.Client{} | ||||||
| 		if err := json.Unmarshal([]byte(f), &client); err != nil { | 		if err := json.Unmarshal(f, &client); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -336,15 +335,15 @@ func GetBroadcastIP(n *net.IPNet) net.IP { | ||||||
| 
 | 
 | ||||||
| // GetBroadcastAndNetworkAddrsLookup get the ip address that can't be used with current server interfaces
 | // GetBroadcastAndNetworkAddrsLookup get the ip address that can't be used with current server interfaces
 | ||||||
| func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]bool { | func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]bool { | ||||||
| 	list := make(map[string]bool, 0) | 	list := make(map[string]bool) | ||||||
| 	for _, ifa := range interfaceAddresses { | 	for _, ifa := range interfaceAddresses { | ||||||
| 		_, net, err := net.ParseCIDR(ifa) | 		_, netAddr, err := net.ParseCIDR(ifa) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		broadcastAddr := GetBroadcastIP(net).String() | 		broadcastAddr := GetBroadcastIP(netAddr).String() | ||||||
| 		networkAddr := net.IP.String() | 		networkAddr := netAddr.IP.String() | ||||||
| 		list[broadcastAddr] = true | 		list[broadcastAddr] = true | ||||||
| 		list[networkAddr] = true | 		list[networkAddr] = true | ||||||
| 	} | 	} | ||||||
|  | @ -354,14 +353,14 @@ func GetBroadcastAndNetworkAddrsLookup(interfaceAddresses []string) map[string]b | ||||||
| // GetAvailableIP get the ip address that can be allocated from an CIDR
 | // GetAvailableIP get the ip address that can be allocated from an CIDR
 | ||||||
| // We need interfaceAddresses to find real broadcast and network addresses
 | // We need interfaceAddresses to find real broadcast and network addresses
 | ||||||
| func GetAvailableIP(cidr string, allocatedList, interfaceAddresses []string) (string, error) { | func GetAvailableIP(cidr string, allocatedList, interfaceAddresses []string) (string, error) { | ||||||
| 	ip, net, err := net.ParseCIDR(cidr) | 	ip, netAddr, err := net.ParseCIDR(cidr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	unavailableIPs := GetBroadcastAndNetworkAddrsLookup(interfaceAddresses) | 	unavailableIPs := GetBroadcastAndNetworkAddrsLookup(interfaceAddresses) | ||||||
| 
 | 
 | ||||||
| 	for ip := ip.Mask(net.Mask); net.Contains(ip); inc(ip) { | 	for ip := ip.Mask(netAddr.Mask); netAddr.Contains(ip); inc(ip) { | ||||||
| 		available := true | 		available := true | ||||||
| 		suggestedAddr := ip.String() | 		suggestedAddr := ip.String() | ||||||
| 		for _, allocatedAddr := range allocatedList { | 		for _, allocatedAddr := range allocatedList { | ||||||
|  | @ -386,7 +385,7 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip | ||||||
| 
 | 
 | ||||||
| 		// clientCIDR must be in CIDR format
 | 		// clientCIDR must be in CIDR format
 | ||||||
| 		if ip == nil { | 		if ip == nil { | ||||||
| 			return false, fmt.Errorf("Invalid ip allocation input %s. Must be in CIDR format", clientCIDR) | 			return false, fmt.Errorf("invalid ip allocation input %s. Must be in CIDR format", clientCIDR) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// return false immediately if the ip is already in use (in ipAllocatedList)
 | 		// return false immediately if the ip is already in use (in ipAllocatedList)
 | ||||||
|  | @ -398,7 +397,7 @@ func ValidateIPAllocation(serverAddresses []string, ipAllocatedList []string, ip | ||||||
| 
 | 
 | ||||||
| 		// even if it is not in use, we still need to check if it
 | 		// even if it is not in use, we still need to check if it
 | ||||||
| 		// belongs to a network of the server.
 | 		// belongs to a network of the server.
 | ||||||
| 		var isValid bool = false | 		var isValid = false | ||||||
| 		for _, serverCIDR := range serverAddresses { | 		for _, serverCIDR := range serverAddresses { | ||||||
| 			_, serverNet, _ := net.ParseCIDR(serverCIDR) | 			_, serverNet, _ := net.ParseCIDR(serverCIDR) | ||||||
| 			if serverNet.Contains(ip) { | 			if serverNet.Contains(ip) { | ||||||
|  | @ -437,7 +436,7 @@ func findSubnetRangeForIP(cidr string) (uint16, error) { | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return 0, fmt.Errorf("Subnet range not found for this IP") | 	return 0, fmt.Errorf("subnet range not found for this IP") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // FillClientSubnetRange to fill subnet ranges client belongs to, does nothing if SRs are not found
 | // FillClientSubnetRange to fill subnet ranges client belongs to, does nothing if SRs are not found
 | ||||||
|  | @ -470,11 +469,11 @@ func ValidateAndFixSubnetRanges(db store.IStore) error { | ||||||
| 	var serverSubnets []*net.IPNet | 	var serverSubnets []*net.IPNet | ||||||
| 	for _, addr := range server.Interface.Addresses { | 	for _, addr := range server.Interface.Addresses { | ||||||
| 		addr = strings.TrimSpace(addr) | 		addr = strings.TrimSpace(addr) | ||||||
| 		_, net, err := net.ParseCIDR(addr) | 		_, netAddr, err := net.ParseCIDR(addr) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| 		serverSubnets = append(serverSubnets, net) | 		serverSubnets = append(serverSubnets, netAddr) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, rng := range SubnetRangesOrder { | 	for _, rng := range SubnetRangesOrder { | ||||||
|  | @ -544,7 +543,7 @@ func WriteWireGuardServerConfig(tmplDir fs.FS, serverConfig model.Server, client | ||||||
| 
 | 
 | ||||||
| 	// if set, read wg.conf template from WgConfTemplate
 | 	// if set, read wg.conf template from WgConfTemplate
 | ||||||
| 	if len(WgConfTemplate) > 0 { | 	if len(WgConfTemplate) > 0 { | ||||||
| 		fileContentBytes, err := ioutil.ReadFile(WgConfTemplate) | 		fileContentBytes, err := os.ReadFile(WgConfTemplate) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue