feat: add GitHub OAuth handlers and wiring

This commit is contained in:
devcxl 2026-04-04 07:34:52 +08:00
parent bdbb9250b8
commit 97f6ada0e0
5 changed files with 1292 additions and 14 deletions

5
go.mod
View File

@ -1,10 +1,9 @@
module github.com/ngoduykhanh/wireguard-ui
go 1.21
go 1.25.0
require (
github.com/glendc/go-external-ip v0.1.0
github.com/gorilla/sessions v1.2.2
github.com/labstack/echo-contrib v0.15.0
github.com/labstack/echo/v4 v4.11.4
@ -22,6 +21,8 @@ require (
gopkg.in/go-playground/validator.v9 v9.31.0
)
require golang.org/x/oauth2 v0.36.0
require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect

4
go.sum
View File

@ -1,5 +1,3 @@
github.com/NicoNex/echotron/v3 v3.27.0 h1:iq4BLPO+Dz1JHjh2HPk0D0NldAZSYcAjaOicgYEhUzw=
github.com/NicoNex/echotron/v3 v3.27.0/go.mod h1:LpP5IyHw0y+DZUZMBgXEDAF9O8feXrQu7w7nlJzzoZI=
github.com/coreos/bbolt v1.3.1-coreos.6.0.20180223184059-4f5275f4ebbf/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -132,6 +130,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

View File

@ -0,0 +1,364 @@
package handler
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"regexp"
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"golang.org/x/oauth2"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/ngoduykhanh/wireguard-ui/store"
"github.com/ngoduykhanh/wireguard-ui/util"
)
var (
GitHubOAuth2Config *oauth2.Config
GitHubAllowedUsers []string
GitHubAllowedOrgs []string
GitHubAdminUsers []string
GitHubUserInfoURL = "https://api.github.com/user"
GitHubOrgsURL = "https://api.github.com/user/memberships/orgs"
)
var githubLoginRegexp = regexp.MustCompile(`^\w[\w\-.]*$`)
var errGitHubUsernameConflict = errors.New("username is already bound to a different GitHub account")
func ApplyGitHubAuthConfig(config util.GitHubAuthConfig) error {
clientSecret := strings.TrimSpace(config.ClientSecret)
if clientSecret == "" && config.ClientSecretFile != "" {
secret, err := os.ReadFile(config.ClientSecretFile)
if err != nil {
return err
}
clientSecret = strings.TrimSpace(string(secret))
}
GitHubOAuth2Config = &oauth2.Config{
ClientID: config.ClientID,
ClientSecret: clientSecret,
RedirectURL: config.RedirectURL,
Scopes: []string{"read:user", "read:org"},
Endpoint: oauth2.Endpoint{
AuthURL: "https://github.com/login/oauth/authorize",
TokenURL: "https://github.com/login/oauth/access_token",
},
}
GitHubAllowedUsers = append([]string(nil), config.AllowedUsers...)
GitHubAllowedOrgs = append([]string(nil), config.AllowedOrgs...)
GitHubAdminUsers = append([]string(nil), config.AdminUsers...)
return nil
}
type githubUserResponse struct {
Login string `json:"login"`
ID int64 `json:"id"`
Name string `json:"name"`
}
type githubOrgMembership struct {
Organization struct {
Login string `json:"login"`
} `json:"organization"`
State string `json:"state"`
}
func GitHubStart() echo.HandlerFunc {
return func(c echo.Context) error {
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
next := c.QueryParam("next")
if next == "" {
next = "/"
}
if !isSafeNextURL(next) {
next = "/"
}
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = next
sess.Save(c.Request(), c.Response())
authURL := GitHubOAuth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline)
return c.Redirect(http.StatusTemporaryRedirect, authURL)
}
}
func GitHubCallback(db store.IStore) echo.HandlerFunc {
return func(c echo.Context) error {
sess, _ := session.Get("session", c)
state := c.QueryParam("state")
expectedState, _ := sess.Values["oauth_state"].(string)
if state == "" || expectedState == "" || state != expectedState {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "invalid oauth state"})
}
expiresAt, _ := sess.Values["oauth_state_expires_at"].(int64)
if time.Now().UTC().Unix() > expiresAt {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "oauth state expired"})
}
next, _ := sess.Values["oauth_next"].(string)
if !isSafeNextURL(next) {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "invalid redirect target"})
}
code := c.QueryParam("code")
if code == "" {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "missing authorization code"})
}
token, err := GitHubOAuth2Config.Exchange(c.Request().Context(), code)
if err != nil {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "failed to exchange code for token"})
}
githubUser, err := fetchGitHubUser(token.AccessToken)
if err != nil {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "failed to fetch GitHub user"})
}
githubLogin := strings.ToLower(githubUser.Login)
if !isValidGitHubLogin(githubLogin) {
clearOAuthSession(c)
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "invalid GitHub username"})
}
isAllowed := isGitHubUserAllowed(githubLogin, token.AccessToken)
if !isAllowed {
clearOAuthSession(c)
return c.JSON(http.StatusForbidden, jsonHTTPResponse{false, "user not authorized"})
}
isAdmin := isGitHubAdmin(githubLogin)
user, originalUsername, err := findOrMigrateGitHubUser(db, githubLogin, githubUser.ID, githubUser.Name)
if err != nil {
clearOAuthSession(c)
if errors.Is(err, errGitHubUsernameConflict) {
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, err.Error()})
}
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to resolve GitHub user"})
}
user.Username = githubLogin
user.DisplayName = githubUser.Name
user.AuthSource = "github"
user.AuthSubject = fmt.Sprintf("%d", githubUser.ID)
user.Password = ""
user.PasswordHash = ""
user.Admin = isAdmin
if originalUsername != "" && originalUsername != githubLogin {
if err := db.ReplaceUser(originalUsername, *user); err != nil {
clearOAuthSession(c)
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to update user"})
}
} else {
if err := db.SaveUser(*user); err != nil {
clearOAuthSession(c)
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to save user"})
}
}
rememberMe := getSessionRememberMe(sess)
if err := establishAuthenticatedSession(c, *user, rememberMe); err != nil {
clearOAuthSession(c)
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to establish session"})
}
clearOAuthSession(c)
return c.Redirect(http.StatusTemporaryRedirect, next)
}
}
func generateOAuthState() string {
b := make([]byte, 32)
_, _ = rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
func isSafeNextURL(next string) bool {
if next == "" {
return false
}
u, err := url.Parse(next)
if err != nil {
return false
}
if u.Opaque != "" && !strings.HasPrefix(next, "/") {
return false
}
if u.Host != "" {
return false
}
if !strings.HasPrefix(next, "/") {
return false
}
return true
}
func isValidGitHubLogin(login string) bool {
if login == "" {
return false
}
return githubLoginRegexp.MatchString(login)
}
func isGitHubUserAllowed(login string, accessToken string) bool {
loginLower := strings.ToLower(login)
if len(GitHubAllowedUsers) > 0 {
for _, allowed := range GitHubAllowedUsers {
if strings.ToLower(allowed) == loginLower {
return true
}
}
}
if len(GitHubAllowedOrgs) > 0 {
orgs, err := fetchGitHubOrgs(accessToken)
if err == nil {
for _, org := range orgs {
for _, allowedOrg := range GitHubAllowedOrgs {
if strings.ToLower(org) == strings.ToLower(allowedOrg) {
return true
}
}
}
}
}
return false
}
func isGitHubAdmin(login string) bool {
loginLower := strings.ToLower(login)
for _, admin := range GitHubAdminUsers {
if strings.ToLower(admin) == loginLower {
return true
}
}
return false
}
func fetchGitHubUser(accessToken string) (*githubUserResponse, error) {
req, err := http.NewRequest(http.MethodGet, GitHubUserInfoURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("github api returned status %d", resp.StatusCode)
}
var user githubUserResponse
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, err
}
return &user, nil
}
func fetchGitHubOrgs(accessToken string) ([]string, error) {
req, err := http.NewRequest(http.MethodGet, GitHubOrgsURL, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("github api returned status %d", resp.StatusCode)
}
var memberships []githubOrgMembership
if err := json.NewDecoder(resp.Body).Decode(&memberships); err != nil {
return nil, err
}
var orgs []string
for _, m := range memberships {
if m.State == "active" {
orgs = append(orgs, strings.ToLower(m.Organization.Login))
}
}
return orgs, nil
}
func clearOAuthSession(c echo.Context) {
sess, _ := session.Get("session", c)
delete(sess.Values, "oauth_state")
delete(sess.Values, "oauth_state_expires_at")
delete(sess.Values, "oauth_next")
sess.Save(c.Request(), c.Response())
}
func getSessionRememberMe(sess *sessions.Session) bool {
maxAge, ok := sess.Values["max_age"].(int)
if !ok {
return false
}
return maxAge > 0
}
func findOrMigrateGitHubUser(db store.IStore, githubLogin string, githubID int64, displayName string) (*model.User, string, error) {
existingUser, err := db.GetUserByAuthIdentity("github", fmt.Sprintf("%d", githubID))
if err == nil {
return &existingUser, existingUser.Username, nil
}
localUser, err := db.GetUserByName(githubLogin)
if err == nil {
if localUser.AuthSource == "local" && localUser.AuthSubject == "" {
return &localUser, localUser.Username, nil
}
if localUser.AuthSource == "github" && localUser.AuthSubject != "" && localUser.AuthSubject != fmt.Sprintf("%d", githubID) {
return nil, "", errGitHubUsernameConflict
}
return &localUser, localUser.Username, nil
}
newUser := model.User{
Username: githubLogin,
DisplayName: displayName,
}
return &newUser, "", nil
}

View File

@ -0,0 +1,892 @@
package handler
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"hash/crc32"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
"github.com/ngoduykhanh/wireguard-ui/model"
"github.com/ngoduykhanh/wireguard-ui/util"
"golang.org/x/oauth2"
)
var (
githubAPIHandler http.Handler
githubTokenURL string
githubUserURL string
githubOrgsURL string
)
func init() {
util.SessionMaxDuration = 86400 * 90
util.DBUsersToCRC32 = map[string]uint32{}
}
func getTestUserCRC32(user model.User) uint32 {
h := crc32.NewIEEE()
writeHashField := func(h io.Writer, value string) {
var length [4]byte
binary.BigEndian.PutUint32(length[:], uint32(len(value)))
_, _ = h.Write(length[:])
_, _ = h.Write([]byte(value))
}
writeHashField(h, user.Username)
if user.Admin {
writeHashField(h, "1")
} else {
writeHashField(h, "0")
}
writeHashField(h, user.Password)
writeHashField(h, user.PasswordHash)
writeHashField(h, user.AuthSource)
writeHashField(h, user.AuthSubject)
return h.Sum32()
}
func setupTestServer() (*httptest.Server, *http.ServeMux) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
githubTokenURL = server.URL + "/oauth/token"
githubUserURL = server.URL + "/user"
githubOrgsURL = server.URL + "/user/memberships/orgs"
return server, mux
}
func setupGitHubAPIHandler(mux *http.ServeMux, userLogin string, userID int64, orgs []string, isOrgMember bool) {
mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"access_token": "test-access-token",
"token_type": "Bearer",
"scope": "read:user,user:email",
})
})
mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader != "Bearer test-access-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"login": userLogin,
"id": userID,
"name": "Test User",
})
})
mux.HandleFunc("/user/memberships/orgs", func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader != "Bearer test-access-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
if isOrgMember {
var orgsData []map[string]interface{}
for _, org := range orgs {
orgsData = append(orgsData, map[string]interface{}{
"organization": map[string]interface{}{
"login": org,
},
"state": "active",
})
}
json.NewEncoder(w).Encode(orgsData)
} else {
json.NewEncoder(w).Encode([]map[string]interface{}{})
}
})
}
func TestApplyGitHubAuthConfig_SetsRuntimeGlobals(t *testing.T) {
tmpDir := t.TempDir()
secretFile := filepath.Join(tmpDir, "github-secret")
if err := os.WriteFile(secretFile, []byte("test-secret"), 0600); err != nil {
t.Fatalf("failed to write secret file: %v", err)
}
config := util.GitHubAuthConfig{
ClientID: "client-id",
ClientSecretFile: secretFile,
RedirectURL: "https://example.com/auth/github/callback",
AllowedUsers: []string{"user1"},
AllowedOrgs: []string{"org1"},
AdminUsers: []string{"admin1"},
}
if err := ApplyGitHubAuthConfig(config); err != nil {
t.Fatalf("expected ApplyGitHubAuthConfig to succeed, got %v", err)
}
if GitHubOAuth2Config == nil {
t.Fatal("expected GitHubOAuth2Config to be initialized")
}
if GitHubOAuth2Config.ClientID != "client-id" {
t.Fatalf("expected client id to be bridged, got %q", GitHubOAuth2Config.ClientID)
}
if GitHubOAuth2Config.ClientSecret != "test-secret" {
t.Fatalf("expected client secret to be loaded from file, got %q", GitHubOAuth2Config.ClientSecret)
}
if GitHubOAuth2Config.RedirectURL != "https://example.com/auth/github/callback" {
t.Fatalf("expected redirect url to be bridged, got %q", GitHubOAuth2Config.RedirectURL)
}
if len(GitHubAllowedUsers) != 1 || GitHubAllowedUsers[0] != "user1" {
t.Fatalf("expected allowed users to be bridged, got %#v", GitHubAllowedUsers)
}
if len(GitHubAllowedOrgs) != 1 || GitHubAllowedOrgs[0] != "org1" {
t.Fatalf("expected allowed orgs to be bridged, got %#v", GitHubAllowedOrgs)
}
if len(GitHubAdminUsers) != 1 || GitHubAdminUsers[0] != "admin1" {
t.Fatalf("expected admin users to be bridged, got %#v", GitHubAdminUsers)
}
if GitHubOAuth2Config.Endpoint.AuthURL == "" || GitHubOAuth2Config.Endpoint.TokenURL == "" {
t.Fatal("expected GitHub OAuth endpoints to be initialized")
}
if len(GitHubOAuth2Config.Scopes) != 2 || GitHubOAuth2Config.Scopes[0] != "read:user" || GitHubOAuth2Config.Scopes[1] != "read:org" {
t.Fatalf("expected GitHub OAuth scopes to be initialized, got %#v", GitHubOAuth2Config.Scopes)
}
}
type testApp struct {
E *echo.Echo
CookieStore *sessions.CookieStore
}
func newTestApp() *testApp {
e := echo.New()
secret := make([]byte, 64)
rand.Read(secret)
cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
cookieStore.Options.Path = "/"
cookieStore.Options.HttpOnly = true
cookieStore.MaxAge(86400 * 7)
e.Use(session.Middleware(cookieStore))
return &testApp{E: e, CookieStore: cookieStore}
}
func newTestAppWithFixedSecret() *testApp {
e := echo.New()
secret, _ := hex.DecodeString("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
cookieStore.Options.Path = "/"
cookieStore.Options.HttpOnly = true
cookieStore.MaxAge(86400 * 7)
e.Use(session.Middleware(cookieStore))
return &testApp{E: e, CookieStore: cookieStore}
}
type mockStore struct {
users []model.User
savedUsers map[string]model.User
deletedUsers []string
}
func (m *mockStore) Init() error { return nil }
func (m *mockStore) GetUsers() ([]model.User, error) { return m.users, nil }
func (m *mockStore) GetUserByName(username string) (model.User, error) {
for _, u := range m.users {
if u.Username == username {
return u, nil
}
}
return model.User{}, fmt.Errorf("not found")
}
func (m *mockStore) GetUserByAuthIdentity(authSource, authSubject string) (model.User, error) {
for _, u := range m.users {
if u.AuthSource == authSource && u.AuthSubject == authSubject {
return u, nil
}
}
return model.User{}, fmt.Errorf("not found")
}
func (m *mockStore) SaveUser(user model.User) error {
if m.savedUsers == nil {
m.savedUsers = make(map[string]model.User)
}
m.savedUsers[user.Username] = user
for i, u := range m.users {
if u.Username == user.Username {
m.users[i] = user
return nil
}
}
m.users = append(m.users, user)
return nil
}
func (m *mockStore) ReplaceUser(oldUsername string, user model.User) error {
return m.SaveUser(user)
}
func (m *mockStore) DeleteUser(username string) error {
m.deletedUsers = append(m.deletedUsers, username)
return nil
}
func (m *mockStore) GetGlobalSettings() (model.GlobalSetting, error) {
return model.GlobalSetting{}, nil
}
func (m *mockStore) GetServer() (model.Server, error) { return model.Server{}, nil }
func (m *mockStore) GetClients(bool) ([]model.ClientData, error) { return nil, nil }
func (m *mockStore) GetClientByID(string, model.QRCodeSettings) (model.ClientData, error) {
return model.ClientData{}, nil
}
func (m *mockStore) SaveClient(model.Client) error { return nil }
func (m *mockStore) DeleteClient(string) error { return nil }
func (m *mockStore) SaveServerInterface(model.ServerInterface) error { return nil }
func (m *mockStore) SaveServerKeyPair(model.ServerKeypair) error { return nil }
func (m *mockStore) SaveGlobalSettings(model.GlobalSetting) error { return nil }
func (m *mockStore) GetWakeOnLanHosts() ([]model.WakeOnLanHost, error) { return nil, nil }
func (m *mockStore) GetWakeOnLanHost(string) (*model.WakeOnLanHost, error) { return nil, nil }
func (m *mockStore) DeleteWakeOnHostLanHost(string) error { return nil }
func (m *mockStore) SaveWakeOnLanHost(model.WakeOnLanHost) error { return nil }
func (m *mockStore) DeleteWakeOnHost(model.WakeOnLanHost) error { return nil }
func (m *mockStore) GetPath() string { return "" }
func (m *mockStore) SaveHashes(model.ClientServerHashes) error { return nil }
func (m *mockStore) GetHashes() (model.ClientServerHashes, error) {
return model.ClientServerHashes{}, nil
}
func TestGitHubCallback_AllowsWhitelistedUser(t *testing.T) {
server, mux := setupTestServer()
defer server.Close()
setupGitHubAPIHandler(mux, "whitelistuser", 1001, nil, false)
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{"whitelistuser"}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
GitHubUserInfoURL = githubUserURL
GitHubOrgsURL = githubOrgsURL
mockDB := &mockStore{}
app := newTestApp()
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("start handler failed: %d", rec.Code)
}
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusTemporaryRedirect {
t.Fatalf("expected redirect 307, got %d. body: %s, location: %s", rec.Code, rec.Body.String(), rec.Header().Get("Location"))
}
if rec.Header().Get("Location") != "/" {
t.Errorf("expected redirect to '/', got '%s'", rec.Header().Get("Location"))
}
sessionCookies := rec.Result().Cookies()
var sessionToken string
for _, c := range sessionCookies {
if c.Name == "session_token" {
sessionToken = c.Value
break
}
}
if sessionToken == "" {
t.Fatal("expected session_token cookie to be set")
}
savedUser, ok := mockDB.savedUsers["whitelistuser"]
if !ok {
t.Fatal("expected user to be saved to store")
}
if savedUser.AuthSource != "github" {
t.Errorf("expected AuthSource 'github', got '%s'", savedUser.AuthSource)
}
if savedUser.AuthSubject != "1001" {
t.Errorf("expected AuthSubject '1001', got '%s'", savedUser.AuthSubject)
}
if savedUser.Password != "" {
t.Errorf("expected empty Password, got '%s'", savedUser.Password)
}
if savedUser.PasswordHash != "" {
t.Errorf("expected empty PasswordHash, got '%s'", savedUser.PasswordHash)
}
if savedUser.Admin {
t.Errorf("expected Admin false, got true")
}
}
func TestGitHubCallback_AllowsOrgMember(t *testing.T) {
server, mux := setupTestServer()
defer server.Close()
setupGitHubAPIHandler(mux, "orgmember", 1002, []string{"allowed-org"}, true)
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{}
GitHubAllowedOrgs = []string{"allowed-org"}
GitHubAdminUsers = []string{"admin"}
GitHubUserInfoURL = githubUserURL
GitHubOrgsURL = githubOrgsURL
mockDB := &mockStore{}
app := newTestApp()
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("start handler failed: %d", rec.Code)
}
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusTemporaryRedirect {
t.Fatalf("expected redirect 307, got %d", rec.Code)
}
}
func TestGitHubCallback_RejectsStateMismatch(t *testing.T) {
server, _ := setupTestServer()
defer server.Close()
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{"testuser"}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
mockDB := &mockStore{}
app := newTestApp()
app.E.GET("/github/callback", GitHubCallback(mockDB))
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state=wrong-state", nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", rec.Code)
}
}
func TestGitHubCallback_RejectsExpiredState(t *testing.T) {
server, _ := setupTestServer()
defer server.Close()
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{"testuser"}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
mockDB := &mockStore{}
app := newTestApp()
state := generateOAuthState()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = time.Now().UTC().Add(-1 * time.Minute).Unix()
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", rec.Code)
}
}
func TestGitHubCallback_RejectsExternalNextURL(t *testing.T) {
server, _ := setupTestServer()
defer server.Close()
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{"testuser"}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
mockDB := &mockStore{}
app := newTestApp()
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "http://evil.com/redirect"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", rec.Code)
}
}
func TestGitHubCallback_RejectsUnsafeGitHubLogin(t *testing.T) {
unsafeLogins := []string{
"user with space",
"user\nwith\nnewline",
"",
"-invalid",
}
for _, unsafeLogin := range unsafeLogins {
server, mux := setupTestServer()
setupGitHubAPIHandler(mux, unsafeLogin, 1001, nil, false)
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{strings.ToLower(unsafeLogin)}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
GitHubUserInfoURL = githubUserURL
GitHubOrgsURL = githubOrgsURL
mockDB := &mockStore{}
app := newTestApp()
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
server.Close()
if rec.Code != http.StatusBadRequest {
t.Errorf("login %q: expected status 400, got %d", unsafeLogin, rec.Code)
}
}
}
func TestGitHubCallback_MigratesLegacyLocalUser(t *testing.T) {
server, mux := setupTestServer()
defer server.Close()
setupGitHubAPIHandler(mux, "legacyuser", 1001, nil, false)
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{"legacyuser"}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
GitHubUserInfoURL = githubUserURL
GitHubOrgsURL = githubOrgsURL
mockDB := &mockStore{
users: []model.User{
{
Username: "legacyuser",
Password: "legacy-pass",
PasswordHash: "",
AuthSource: "local",
AuthSubject: "",
Admin: false,
},
},
}
app := newTestApp()
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("start handler failed: %d", rec.Code)
}
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusTemporaryRedirect {
t.Fatalf("expected redirect 307, got %d", rec.Code)
}
savedUser, ok := mockDB.savedUsers["legacyuser"]
if !ok {
t.Fatal("expected user to be saved to store")
}
if savedUser.AuthSource != "github" {
t.Errorf("expected AuthSource 'github', got '%s'", savedUser.AuthSource)
}
if savedUser.AuthSubject != "1001" {
t.Errorf("expected AuthSubject '1001', got '%s'", savedUser.AuthSubject)
}
if savedUser.Password != "" {
t.Errorf("expected empty Password, got '%s'", savedUser.Password)
}
if savedUser.PasswordHash != "" {
t.Errorf("expected empty PasswordHash, got '%s'", savedUser.PasswordHash)
}
if savedUser.DisplayName != "Test User" {
t.Errorf("expected DisplayName 'Test User', got '%s'", savedUser.DisplayName)
}
}
func TestGitHubCallback_RejectsConflictingBoundUser(t *testing.T) {
server, mux := setupTestServer()
defer server.Close()
setupGitHubAPIHandler(mux, "anotheruser", 1002, nil, false)
GitHubOAuth2Config = &oauth2.Config{
ClientID: "test-client-id",
ClientSecret: "test-client-secret",
RedirectURL: "http://localhost:8080/github/callback",
Endpoint: oauth2.Endpoint{
AuthURL: server.URL + "/oauth/authorize",
TokenURL: githubTokenURL,
},
}
GitHubAllowedUsers = []string{"anotheruser"}
GitHubAllowedOrgs = []string{}
GitHubAdminUsers = []string{"admin"}
GitHubUserInfoURL = githubUserURL
GitHubOrgsURL = githubOrgsURL
mockDB := &mockStore{
users: []model.User{
{
Username: "anotheruser",
Password: "local-pass",
PasswordHash: "",
AuthSource: "github",
AuthSubject: "999999",
Admin: false,
},
},
}
app := newTestApp()
state := generateOAuthState()
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
app.E.POST("/github/start", func(c echo.Context) error {
sess, _ := session.Get("session", c)
sess.Values["oauth_state"] = state
sess.Values["oauth_state_expires_at"] = expiresAt
sess.Values["oauth_next"] = "/"
sess.Save(c.Request(), c.Response())
return c.NoContent(http.StatusOK)
})
app.E.GET("/github/callback", GitHubCallback(mockDB))
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
cookies := rec.Result().Cookies()
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", rec.Code)
}
if len(mockDB.savedUsers) > 0 {
t.Errorf("expected no user to be saved, got %d saved users", len(mockDB.savedUsers))
}
}
func TestLocalLoginAndGitHubLoginShareSessionBuilder(t *testing.T) {
testUser := model.User{
Username: "testuser",
Password: "testpass",
PasswordHash: "",
AuthSource: "local",
AuthSubject: "",
Admin: true,
}
util.DBUsersToCRC32["testuser"] = getTestUserCRC32(testUser)
mockDB := &mockStore{
users: []model.User{testUser},
}
app := newTestAppWithFixedSecret()
app.E.POST("/login", Login(mockDB), ContentTypeJson)
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(`{"username":"testuser","password":"testpass","rememberMe":true}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
app.E.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d. body: %s", rec.Code, rec.Body.String())
}
cookies := rec.Result().Cookies()
var sessionToken string
var sessionCookie *http.Cookie
for _, c := range cookies {
if c.Name == "session_token" {
sessionToken = c.Value
sessionCookie = c
break
}
}
if sessionToken == "" {
t.Fatal("expected session_token cookie to be set")
}
if sessionCookie == nil {
t.Fatal("expected session_token cookie to be set")
}
if sessionCookie.Value == "" {
t.Fatal("expected session_token cookie value to be non-empty")
}
if sessionCookie.MaxAge == 0 {
t.Fatal("expected session_token cookie MaxAge to be set for rememberMe=true")
}
req = httptest.NewRequest(http.MethodGet, "/profile", nil)
for _, c := range cookies {
req.AddCookie(c)
}
rec = httptest.NewRecorder()
profileCalled := false
var profileUsername string
var profileAdmin interface{}
app.E.GET("/profile", func(c echo.Context) error {
sess, _ := session.Get("session", c)
profileCalled = true
if sess != nil {
profileUsername, _ = sess.Values["username"].(string)
profileAdmin = sess.Values["admin"]
}
return c.NoContent(http.StatusOK)
}, ValidSession, RefreshSession)
app.E.ServeHTTP(rec, req)
if !profileCalled {
t.Fatal("profile handler was not called")
}
if profileUsername != "testuser" {
t.Errorf("expected username 'testuser', got '%v'", profileUsername)
}
if profileAdmin != true {
t.Errorf("expected admin true, got '%v'", profileAdmin)
}
if sessionCookie.Value != sessionToken {
t.Errorf("expected cookie value '%s', got '%s'", sessionToken, sessionCookie.Value)
}
if sessionCookie.MaxAge != 86400*7 {
t.Errorf("expected cookie MaxAge %d, got %d", 86400*7, sessionCookie.MaxAge)
}
}
var _ = regexp.MustCompile("")

41
main.go
View File

@ -29,7 +29,7 @@ var (
appVersion = "development"
gitCommit = "N/A"
gitRef = "N/A"
buildTime = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05"))
buildTime = time.Now().UTC().Format("01-02-2006 15:04:05")
// configuration variables
flagDisableLogin = false
flagBindAddress = "0.0.0.0:5000"
@ -137,6 +137,17 @@ func init() {
util.BasePath = util.ParseBasePath(flagBasePath)
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
util.CurrentAuthMethod = util.ParseAuthMethod()
if util.CurrentAuthMethod == util.AuthMethodGitHub {
util.CurrentGitHubConfig = util.ParseGitHubAuthConfig()
if err := util.ValidateGitHubAuthConfig(util.CurrentGitHubConfig); err != nil {
log.Fatalf("GitHub OAuth configuration is invalid: %v", err)
}
if err := handler.ApplyGitHubAuthConfig(util.CurrentGitHubConfig); err != nil {
log.Fatalf("Failed to apply GitHub OAuth configuration: %v", err)
}
}
lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO"))
// print only if log level is INFO or lower
@ -174,6 +185,7 @@ func main() {
extraData["gitCommit"] = gitCommit
extraData["basePath"] = util.BasePath
extraData["loginDisabled"] = flagDisableLogin
extraData["authMethod"] = util.CurrentAuthMethod
// strip the "templates/" prefix from the embedded directory so files can be read by their direct name (e.g.
// "base.html" instead of "templates/base.html")
@ -204,15 +216,24 @@ func main() {
if !util.DisableLogin {
app.GET(util.BasePath+"/login", handler.LoginPage())
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/get-users", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
if util.CurrentAuthMethod == util.AuthMethodGitHub {
app.GET(util.BasePath+"/auth/github/start", handler.GitHubStart())
app.GET(util.BasePath+"/auth/github/callback", handler.GitHubCallback(db))
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
} else {
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
app.GET(util.BasePath+"/get-users", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin)
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
}
}
var sendmail emailer.Emailer