Merge b593512946 into 2fdafd34ca
				
					
				
			This commit is contained in:
		
						commit
						b31729a420
					
				|  | @ -8,7 +8,11 @@ import ( | |||
| 	"github.com/gorilla/sessions" | ||||
| 	"github.com/labstack/echo-contrib/session" | ||||
| 	"github.com/labstack/echo/v4" | ||||
| 	"github.com/labstack/gommon/log" | ||||
| 	"github.com/ngoduykhanh/wireguard-ui/model" | ||||
| 	"github.com/ngoduykhanh/wireguard-ui/store/jsondb" | ||||
| 	"github.com/ngoduykhanh/wireguard-ui/util" | ||||
| 	"github.com/rs/xid" | ||||
| ) | ||||
| 
 | ||||
| func ValidSession(next echo.HandlerFunc) echo.HandlerFunc { | ||||
|  | @ -43,6 +47,86 @@ func NeedsAdmin(next echo.HandlerFunc) echo.HandlerFunc { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // SSOauth uses external authentication (usually by reverseproxy) in the form of HTTP header REMOTE_USER
 | ||||
| func SSOauth(next echo.HandlerFunc) echo.HandlerFunc { | ||||
| 	return func(c echo.Context) error { | ||||
| 		if !util.RemoteUser { | ||||
| 			return next(c) | ||||
| 		} | ||||
| 		if !isValidSession(c) { | ||||
| 			remoteUser := c.Request().Header.Get("REMOTE_USER") | ||||
| 			if remoteUser == "" { | ||||
| 				// TODO: Better error handling
 | ||||
| 				log.Infof("No REMOTE_USER in reqest. Bailing out.") | ||||
| 				return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") | ||||
| 			} | ||||
| 			log.Debugf("No valid session for REMOTE_USER: %s", remoteUser) | ||||
| 
 | ||||
| 			db := c.Get("db").(*jsondb.JsonDB) | ||||
| 			dbuser, err := db.GetUserByName(remoteUser) | ||||
| 			if err != nil { | ||||
| 				log.Infof("User %s not in database, creating user", remoteUser) | ||||
| 				newUser := model.User{ | ||||
| 					Username: remoteUser, | ||||
| 					Admin:    false, | ||||
| 				} | ||||
| 				err = db.SaveUser(newUser) | ||||
| 				if err != nil { | ||||
| 					// TODO: Better error handling
 | ||||
| 					return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") | ||||
| 				} | ||||
| 				// Update dbuser from database
 | ||||
| 				dbuser, err = db.GetUserByName(remoteUser) | ||||
| 				if err != nil { | ||||
| 					// TODO: Better error handling
 | ||||
| 					return c.Redirect(http.StatusTemporaryRedirect, util.BasePath+"/") | ||||
| 				} | ||||
| 
 | ||||
| 			} else { | ||||
| 				log.Debugf("Got user from db: %s", dbuser.Username) | ||||
| 			} | ||||
| 
 | ||||
| 			// Set session for REMOTE_USER
 | ||||
| 			ageMax := 0 | ||||
| 
 | ||||
| 			cookiePath := util.GetCookiePath() | ||||
| 
 | ||||
| 			sess, _ := session.Get("session", c) | ||||
| 			sess.Options = &sessions.Options{ | ||||
| 				Path:     cookiePath, | ||||
| 				MaxAge:   ageMax, | ||||
| 				HttpOnly: true, | ||||
| 				SameSite: http.SameSiteLaxMode, | ||||
| 			} | ||||
| 
 | ||||
| 			// set session_token
 | ||||
| 			tokenUID := xid.New().String() | ||||
| 			now := time.Now().UTC().Unix() | ||||
| 			sess.Values["username"] = dbuser.Username | ||||
| 			sess.Values["user_hash"] = util.GetDBUserCRC32(dbuser) | ||||
| 			sess.Values["admin"] = dbuser.Admin | ||||
| 			sess.Values["session_token"] = tokenUID | ||||
| 			sess.Values["max_age"] = ageMax | ||||
| 			sess.Values["created_at"] = now | ||||
| 			sess.Values["updated_at"] = now | ||||
| 			sess.Save(c.Request(), c.Response()) | ||||
| 
 | ||||
| 			// set session_token in cookie
 | ||||
| 			cookie := new(http.Cookie) | ||||
| 			cookie.Name = "session_token" | ||||
| 			cookie.Path = cookiePath | ||||
| 			cookie.Value = tokenUID | ||||
| 			cookie.MaxAge = ageMax | ||||
| 			cookie.HttpOnly = true | ||||
| 			cookie.SameSite = http.SameSiteLaxMode | ||||
| 			c.SetCookie(cookie) | ||||
| 
 | ||||
| 			return c.Redirect(http.StatusTemporaryRedirect, util.BasePath) | ||||
| 		} | ||||
| 		return next(c) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func isValidSession(c echo.Context) bool { | ||||
| 	if util.DisableLogin { | ||||
| 		return true | ||||
|  |  | |||
							
								
								
									
										8
									
								
								main.go
								
								
								
								
							
							
						
						
									
										8
									
								
								main.go
								
								
								
								
							|  | @ -33,6 +33,7 @@ var ( | |||
| 	buildTime  = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05")) | ||||
| 	// configuration variables
 | ||||
| 	flagDisableLogin             = false | ||||
| 	flagRemoteUser               = false | ||||
| 	flagBindAddress              = "0.0.0.0:5000" | ||||
| 	flagSmtpHostname             = "127.0.0.1" | ||||
| 	flagSmtpPort                 = 25 | ||||
|  | @ -77,6 +78,7 @@ var embeddedAssets embed.FS | |||
| func init() { | ||||
| 	// command-line flags and env variables
 | ||||
| 	flag.BoolVar(&flagDisableLogin, "disable-login", util.LookupEnvOrBool("DISABLE_LOGIN", flagDisableLogin), "Disable authentication on the app. This is potentially dangerous.") | ||||
| 	flag.BoolVar(&flagRemoteUser, "remote_user", util.LookupEnvOrBool("REMOTE_USER", flagRemoteUser), "Use HTTP header REMOTE_USER for auth. Commonly used with SSO and a proxy funcion.") | ||||
| 	flag.StringVar(&flagBindAddress, "bind-address", util.LookupEnvOrString("BIND_ADDRESS", flagBindAddress), "Address:Port to which the app will be bound.") | ||||
| 	flag.StringVar(&flagSmtpHostname, "smtp-hostname", util.LookupEnvOrString("SMTP_HOSTNAME", flagSmtpHostname), "SMTP Hostname") | ||||
| 	flag.IntVar(&flagSmtpPort, "smtp-port", util.LookupEnvOrInt("SMTP_PORT", flagSmtpPort), "SMTP Port") | ||||
|  | @ -126,6 +128,7 @@ func init() { | |||
| 
 | ||||
| 	// update runtime config
 | ||||
| 	util.DisableLogin = flagDisableLogin | ||||
| 	util.RemoteUser = flagRemoteUser | ||||
| 	util.BindAddress = flagBindAddress | ||||
| 	util.SmtpHostname = flagSmtpHostname | ||||
| 	util.SmtpPort = flagSmtpPort | ||||
|  | @ -161,6 +164,7 @@ func init() { | |||
| 		fmt.Println("Build Time\t:", buildTime) | ||||
| 		fmt.Println("Git Repo\t:", "https://github.com/ngoduykhanh/wireguard-ui") | ||||
| 		fmt.Println("Authentication\t:", !util.DisableLogin) | ||||
| 		fmt.Println("Remote_user\t:", util.RemoteUser) | ||||
| 		fmt.Println("Bind address\t:", util.BindAddress) | ||||
| 		//fmt.Println("Sendgrid key\t:", util.SendgridApiKey)
 | ||||
| 		fmt.Println("Email from\t:", util.EmailFrom) | ||||
|  | @ -206,9 +210,9 @@ func main() { | |||
| 	} | ||||
| 
 | ||||
| 	// register routes
 | ||||
| 	app := router.New(tmplDir, extraData, util.SessionSecret) | ||||
| 	app := router.New(tmplDir, extraData, util.SessionSecret, db) | ||||
| 
 | ||||
| 	app.GET(util.BasePath, handler.WireGuardClients(db), handler.ValidSession, handler.RefreshSession) | ||||
| 	app.GET(util.BasePath, handler.WireGuardClients(db), handler.SSOauth, handler.ValidSession, handler.RefreshSession) | ||||
| 
 | ||||
| 	// Important: Make sure that all non-GET routes check the request content type using handler.ContentTypeJson to
 | ||||
| 	// mitigate CSRF attacks. This is effective, because browsers don't allow setting the Content-Type header on
 | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ import ( | |||
| 	"github.com/labstack/echo/v4" | ||||
| 	"github.com/labstack/echo/v4/middleware" | ||||
| 	"github.com/labstack/gommon/log" | ||||
| 	"github.com/ngoduykhanh/wireguard-ui/store/jsondb" | ||||
| 	"github.com/ngoduykhanh/wireguard-ui/util" | ||||
| ) | ||||
| 
 | ||||
|  | @ -48,7 +49,7 @@ func (t *TemplateRegistry) Render(w io.Writer, name string, data interface{}, c | |||
| } | ||||
| 
 | ||||
| // New function
 | ||||
| func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo.Echo { | ||||
| func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte, db *jsondb.JsonDB) *echo.Echo { | ||||
| 	e := echo.New() | ||||
| 
 | ||||
| 	cookiePath := util.GetCookiePath() | ||||
|  | @ -60,6 +61,14 @@ func New(tmplDir fs.FS, extraData map[string]interface{}, secret [64]byte) *echo | |||
| 
 | ||||
| 	e.Use(session.Middleware(cookieStore)) | ||||
| 
 | ||||
| 	// Add db to context so middlewares can use it.
 | ||||
| 	e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { | ||||
| 		return func(c echo.Context) error { | ||||
| 			c.Set("db", db) | ||||
| 			return next(c) | ||||
| 		} | ||||
| 	}) | ||||
| 
 | ||||
| 	// read html template file to string
 | ||||
| 	tmplBaseString, err := util.StringFromEmbedFile(tmplDir, "base.html") | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ import ( | |||
| // Runtime config
 | ||||
| var ( | ||||
| 	DisableLogin       bool | ||||
| 	RemoteUser         bool | ||||
| 	BindAddress        string | ||||
| 	SmtpHostname       string | ||||
| 	SmtpPort           int | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue