feat: add GitHub OAuth handlers and wiring
This commit is contained in:
parent
bdbb9250b8
commit
97f6ada0e0
5
go.mod
5
go.mod
|
|
@ -1,10 +1,9 @@
|
|||
module github.com/ngoduykhanh/wireguard-ui
|
||||
|
||||
go 1.21
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/glendc/go-external-ip v0.1.0
|
||||
|
||||
github.com/gorilla/sessions v1.2.2
|
||||
github.com/labstack/echo-contrib v0.15.0
|
||||
github.com/labstack/echo/v4 v4.11.4
|
||||
|
|
@ -22,6 +21,8 @@ require (
|
|||
gopkg.in/go-playground/validator.v9 v9.31.0
|
||||
)
|
||||
|
||||
require golang.org/x/oauth2 v0.36.0
|
||||
|
||||
require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
|
|
|
|||
4
go.sum
4
go.sum
|
|
@ -1,5 +1,3 @@
|
|||
github.com/NicoNex/echotron/v3 v3.27.0 h1:iq4BLPO+Dz1JHjh2HPk0D0NldAZSYcAjaOicgYEhUzw=
|
||||
github.com/NicoNex/echotron/v3 v3.27.0/go.mod h1:LpP5IyHw0y+DZUZMBgXEDAF9O8feXrQu7w7nlJzzoZI=
|
||||
github.com/coreos/bbolt v1.3.1-coreos.6.0.20180223184059-4f5275f4ebbf/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
|
|
@ -132,6 +130,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
|||
golang.org/x/net v0.0.0-20210504132125-bbd867fde50d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
|
||||
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
|
||||
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
|
|
|
|||
|
|
@ -0,0 +1,364 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/ngoduykhanh/wireguard-ui/model"
|
||||
"github.com/ngoduykhanh/wireguard-ui/store"
|
||||
"github.com/ngoduykhanh/wireguard-ui/util"
|
||||
)
|
||||
|
||||
var (
|
||||
GitHubOAuth2Config *oauth2.Config
|
||||
GitHubAllowedUsers []string
|
||||
GitHubAllowedOrgs []string
|
||||
GitHubAdminUsers []string
|
||||
GitHubUserInfoURL = "https://api.github.com/user"
|
||||
GitHubOrgsURL = "https://api.github.com/user/memberships/orgs"
|
||||
)
|
||||
|
||||
var githubLoginRegexp = regexp.MustCompile(`^\w[\w\-.]*$`)
|
||||
|
||||
var errGitHubUsernameConflict = errors.New("username is already bound to a different GitHub account")
|
||||
|
||||
func ApplyGitHubAuthConfig(config util.GitHubAuthConfig) error {
|
||||
clientSecret := strings.TrimSpace(config.ClientSecret)
|
||||
if clientSecret == "" && config.ClientSecretFile != "" {
|
||||
secret, err := os.ReadFile(config.ClientSecretFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientSecret = strings.TrimSpace(string(secret))
|
||||
}
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: []string{"read:user", "read:org"},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = append([]string(nil), config.AllowedUsers...)
|
||||
GitHubAllowedOrgs = append([]string(nil), config.AllowedOrgs...)
|
||||
GitHubAdminUsers = append([]string(nil), config.AdminUsers...)
|
||||
return nil
|
||||
}
|
||||
|
||||
type githubUserResponse struct {
|
||||
Login string `json:"login"`
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type githubOrgMembership struct {
|
||||
Organization struct {
|
||||
Login string `json:"login"`
|
||||
} `json:"organization"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func GitHubStart() echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
next := c.QueryParam("next")
|
||||
if next == "" {
|
||||
next = "/"
|
||||
}
|
||||
if !isSafeNextURL(next) {
|
||||
next = "/"
|
||||
}
|
||||
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = next
|
||||
sess.Save(c.Request(), c.Response())
|
||||
|
||||
authURL := GitHubOAuth2Config.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||
return c.Redirect(http.StatusTemporaryRedirect, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
func GitHubCallback(db store.IStore) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
|
||||
state := c.QueryParam("state")
|
||||
expectedState, _ := sess.Values["oauth_state"].(string)
|
||||
if state == "" || expectedState == "" || state != expectedState {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "invalid oauth state"})
|
||||
}
|
||||
|
||||
expiresAt, _ := sess.Values["oauth_state_expires_at"].(int64)
|
||||
if time.Now().UTC().Unix() > expiresAt {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "oauth state expired"})
|
||||
}
|
||||
|
||||
next, _ := sess.Values["oauth_next"].(string)
|
||||
if !isSafeNextURL(next) {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "invalid redirect target"})
|
||||
}
|
||||
|
||||
code := c.QueryParam("code")
|
||||
if code == "" {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "missing authorization code"})
|
||||
}
|
||||
|
||||
token, err := GitHubOAuth2Config.Exchange(c.Request().Context(), code)
|
||||
if err != nil {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "failed to exchange code for token"})
|
||||
}
|
||||
|
||||
githubUser, err := fetchGitHubUser(token.AccessToken)
|
||||
if err != nil {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "failed to fetch GitHub user"})
|
||||
}
|
||||
|
||||
githubLogin := strings.ToLower(githubUser.Login)
|
||||
if !isValidGitHubLogin(githubLogin) {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, "invalid GitHub username"})
|
||||
}
|
||||
|
||||
isAllowed := isGitHubUserAllowed(githubLogin, token.AccessToken)
|
||||
if !isAllowed {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusForbidden, jsonHTTPResponse{false, "user not authorized"})
|
||||
}
|
||||
|
||||
isAdmin := isGitHubAdmin(githubLogin)
|
||||
|
||||
user, originalUsername, err := findOrMigrateGitHubUser(db, githubLogin, githubUser.ID, githubUser.Name)
|
||||
if err != nil {
|
||||
clearOAuthSession(c)
|
||||
if errors.Is(err, errGitHubUsernameConflict) {
|
||||
return c.JSON(http.StatusBadRequest, jsonHTTPResponse{false, err.Error()})
|
||||
}
|
||||
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to resolve GitHub user"})
|
||||
}
|
||||
|
||||
user.Username = githubLogin
|
||||
user.DisplayName = githubUser.Name
|
||||
user.AuthSource = "github"
|
||||
user.AuthSubject = fmt.Sprintf("%d", githubUser.ID)
|
||||
user.Password = ""
|
||||
user.PasswordHash = ""
|
||||
user.Admin = isAdmin
|
||||
|
||||
if originalUsername != "" && originalUsername != githubLogin {
|
||||
if err := db.ReplaceUser(originalUsername, *user); err != nil {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to update user"})
|
||||
}
|
||||
} else {
|
||||
if err := db.SaveUser(*user); err != nil {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to save user"})
|
||||
}
|
||||
}
|
||||
|
||||
rememberMe := getSessionRememberMe(sess)
|
||||
|
||||
if err := establishAuthenticatedSession(c, *user, rememberMe); err != nil {
|
||||
clearOAuthSession(c)
|
||||
return c.JSON(http.StatusInternalServerError, jsonHTTPResponse{false, "failed to establish session"})
|
||||
}
|
||||
|
||||
clearOAuthSession(c)
|
||||
|
||||
return c.Redirect(http.StatusTemporaryRedirect, next)
|
||||
}
|
||||
}
|
||||
|
||||
func generateOAuthState() string {
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func isSafeNextURL(next string) bool {
|
||||
if next == "" {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(next)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if u.Opaque != "" && !strings.HasPrefix(next, "/") {
|
||||
return false
|
||||
}
|
||||
if u.Host != "" {
|
||||
return false
|
||||
}
|
||||
if !strings.HasPrefix(next, "/") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidGitHubLogin(login string) bool {
|
||||
if login == "" {
|
||||
return false
|
||||
}
|
||||
return githubLoginRegexp.MatchString(login)
|
||||
}
|
||||
|
||||
func isGitHubUserAllowed(login string, accessToken string) bool {
|
||||
loginLower := strings.ToLower(login)
|
||||
if len(GitHubAllowedUsers) > 0 {
|
||||
for _, allowed := range GitHubAllowedUsers {
|
||||
if strings.ToLower(allowed) == loginLower {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(GitHubAllowedOrgs) > 0 {
|
||||
orgs, err := fetchGitHubOrgs(accessToken)
|
||||
if err == nil {
|
||||
for _, org := range orgs {
|
||||
for _, allowedOrg := range GitHubAllowedOrgs {
|
||||
if strings.ToLower(org) == strings.ToLower(allowedOrg) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isGitHubAdmin(login string) bool {
|
||||
loginLower := strings.ToLower(login)
|
||||
for _, admin := range GitHubAdminUsers {
|
||||
if strings.ToLower(admin) == loginLower {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func fetchGitHubUser(accessToken string) (*githubUserResponse, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, GitHubUserInfoURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("github api returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var user githubUserResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func fetchGitHubOrgs(accessToken string) ([]string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, GitHubOrgsURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("github api returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var memberships []githubOrgMembership
|
||||
if err := json.NewDecoder(resp.Body).Decode(&memberships); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var orgs []string
|
||||
for _, m := range memberships {
|
||||
if m.State == "active" {
|
||||
orgs = append(orgs, strings.ToLower(m.Organization.Login))
|
||||
}
|
||||
}
|
||||
|
||||
return orgs, nil
|
||||
}
|
||||
|
||||
func clearOAuthSession(c echo.Context) {
|
||||
sess, _ := session.Get("session", c)
|
||||
delete(sess.Values, "oauth_state")
|
||||
delete(sess.Values, "oauth_state_expires_at")
|
||||
delete(sess.Values, "oauth_next")
|
||||
sess.Save(c.Request(), c.Response())
|
||||
}
|
||||
|
||||
func getSessionRememberMe(sess *sessions.Session) bool {
|
||||
maxAge, ok := sess.Values["max_age"].(int)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return maxAge > 0
|
||||
}
|
||||
|
||||
func findOrMigrateGitHubUser(db store.IStore, githubLogin string, githubID int64, displayName string) (*model.User, string, error) {
|
||||
existingUser, err := db.GetUserByAuthIdentity("github", fmt.Sprintf("%d", githubID))
|
||||
if err == nil {
|
||||
return &existingUser, existingUser.Username, nil
|
||||
}
|
||||
|
||||
localUser, err := db.GetUserByName(githubLogin)
|
||||
if err == nil {
|
||||
if localUser.AuthSource == "local" && localUser.AuthSubject == "" {
|
||||
return &localUser, localUser.Username, nil
|
||||
}
|
||||
if localUser.AuthSource == "github" && localUser.AuthSubject != "" && localUser.AuthSubject != fmt.Sprintf("%d", githubID) {
|
||||
return nil, "", errGitHubUsernameConflict
|
||||
}
|
||||
return &localUser, localUser.Username, nil
|
||||
}
|
||||
|
||||
newUser := model.User{
|
||||
Username: githubLogin,
|
||||
DisplayName: displayName,
|
||||
}
|
||||
|
||||
return &newUser, "", nil
|
||||
}
|
||||
|
|
@ -0,0 +1,892 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/ngoduykhanh/wireguard-ui/model"
|
||||
"github.com/ngoduykhanh/wireguard-ui/util"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
githubAPIHandler http.Handler
|
||||
githubTokenURL string
|
||||
githubUserURL string
|
||||
githubOrgsURL string
|
||||
)
|
||||
|
||||
func init() {
|
||||
util.SessionMaxDuration = 86400 * 90
|
||||
util.DBUsersToCRC32 = map[string]uint32{}
|
||||
}
|
||||
|
||||
func getTestUserCRC32(user model.User) uint32 {
|
||||
h := crc32.NewIEEE()
|
||||
writeHashField := func(h io.Writer, value string) {
|
||||
var length [4]byte
|
||||
binary.BigEndian.PutUint32(length[:], uint32(len(value)))
|
||||
_, _ = h.Write(length[:])
|
||||
_, _ = h.Write([]byte(value))
|
||||
}
|
||||
writeHashField(h, user.Username)
|
||||
if user.Admin {
|
||||
writeHashField(h, "1")
|
||||
} else {
|
||||
writeHashField(h, "0")
|
||||
}
|
||||
writeHashField(h, user.Password)
|
||||
writeHashField(h, user.PasswordHash)
|
||||
writeHashField(h, user.AuthSource)
|
||||
writeHashField(h, user.AuthSubject)
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
func setupTestServer() (*httptest.Server, *http.ServeMux) {
|
||||
mux := http.NewServeMux()
|
||||
server := httptest.NewServer(mux)
|
||||
githubTokenURL = server.URL + "/oauth/token"
|
||||
githubUserURL = server.URL + "/user"
|
||||
githubOrgsURL = server.URL + "/user/memberships/orgs"
|
||||
return server, mux
|
||||
}
|
||||
|
||||
func setupGitHubAPIHandler(mux *http.ServeMux, userLogin string, userID int64, orgs []string, isOrgMember bool) {
|
||||
mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.FormValue("grant_type") != "authorization_code" {
|
||||
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"access_token": "test-access-token",
|
||||
"token_type": "Bearer",
|
||||
"scope": "read:user,user:email",
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("/user", func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader != "Bearer test-access-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"login": userLogin,
|
||||
"id": userID,
|
||||
"name": "Test User",
|
||||
})
|
||||
})
|
||||
mux.HandleFunc("/user/memberships/orgs", func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader != "Bearer test-access-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if isOrgMember {
|
||||
var orgsData []map[string]interface{}
|
||||
for _, org := range orgs {
|
||||
orgsData = append(orgsData, map[string]interface{}{
|
||||
"organization": map[string]interface{}{
|
||||
"login": org,
|
||||
},
|
||||
"state": "active",
|
||||
})
|
||||
}
|
||||
json.NewEncoder(w).Encode(orgsData)
|
||||
} else {
|
||||
json.NewEncoder(w).Encode([]map[string]interface{}{})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyGitHubAuthConfig_SetsRuntimeGlobals(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
secretFile := filepath.Join(tmpDir, "github-secret")
|
||||
if err := os.WriteFile(secretFile, []byte("test-secret"), 0600); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
config := util.GitHubAuthConfig{
|
||||
ClientID: "client-id",
|
||||
ClientSecretFile: secretFile,
|
||||
RedirectURL: "https://example.com/auth/github/callback",
|
||||
AllowedUsers: []string{"user1"},
|
||||
AllowedOrgs: []string{"org1"},
|
||||
AdminUsers: []string{"admin1"},
|
||||
}
|
||||
|
||||
if err := ApplyGitHubAuthConfig(config); err != nil {
|
||||
t.Fatalf("expected ApplyGitHubAuthConfig to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if GitHubOAuth2Config == nil {
|
||||
t.Fatal("expected GitHubOAuth2Config to be initialized")
|
||||
}
|
||||
if GitHubOAuth2Config.ClientID != "client-id" {
|
||||
t.Fatalf("expected client id to be bridged, got %q", GitHubOAuth2Config.ClientID)
|
||||
}
|
||||
if GitHubOAuth2Config.ClientSecret != "test-secret" {
|
||||
t.Fatalf("expected client secret to be loaded from file, got %q", GitHubOAuth2Config.ClientSecret)
|
||||
}
|
||||
if GitHubOAuth2Config.RedirectURL != "https://example.com/auth/github/callback" {
|
||||
t.Fatalf("expected redirect url to be bridged, got %q", GitHubOAuth2Config.RedirectURL)
|
||||
}
|
||||
if len(GitHubAllowedUsers) != 1 || GitHubAllowedUsers[0] != "user1" {
|
||||
t.Fatalf("expected allowed users to be bridged, got %#v", GitHubAllowedUsers)
|
||||
}
|
||||
if len(GitHubAllowedOrgs) != 1 || GitHubAllowedOrgs[0] != "org1" {
|
||||
t.Fatalf("expected allowed orgs to be bridged, got %#v", GitHubAllowedOrgs)
|
||||
}
|
||||
if len(GitHubAdminUsers) != 1 || GitHubAdminUsers[0] != "admin1" {
|
||||
t.Fatalf("expected admin users to be bridged, got %#v", GitHubAdminUsers)
|
||||
}
|
||||
if GitHubOAuth2Config.Endpoint.AuthURL == "" || GitHubOAuth2Config.Endpoint.TokenURL == "" {
|
||||
t.Fatal("expected GitHub OAuth endpoints to be initialized")
|
||||
}
|
||||
if len(GitHubOAuth2Config.Scopes) != 2 || GitHubOAuth2Config.Scopes[0] != "read:user" || GitHubOAuth2Config.Scopes[1] != "read:org" {
|
||||
t.Fatalf("expected GitHub OAuth scopes to be initialized, got %#v", GitHubOAuth2Config.Scopes)
|
||||
}
|
||||
}
|
||||
|
||||
type testApp struct {
|
||||
E *echo.Echo
|
||||
CookieStore *sessions.CookieStore
|
||||
}
|
||||
|
||||
func newTestApp() *testApp {
|
||||
e := echo.New()
|
||||
secret := make([]byte, 64)
|
||||
rand.Read(secret)
|
||||
cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
|
||||
cookieStore.Options.Path = "/"
|
||||
cookieStore.Options.HttpOnly = true
|
||||
cookieStore.MaxAge(86400 * 7)
|
||||
e.Use(session.Middleware(cookieStore))
|
||||
return &testApp{E: e, CookieStore: cookieStore}
|
||||
}
|
||||
|
||||
func newTestAppWithFixedSecret() *testApp {
|
||||
e := echo.New()
|
||||
secret, _ := hex.DecodeString("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
|
||||
cookieStore := sessions.NewCookieStore(secret[:32], secret[32:])
|
||||
cookieStore.Options.Path = "/"
|
||||
cookieStore.Options.HttpOnly = true
|
||||
cookieStore.MaxAge(86400 * 7)
|
||||
e.Use(session.Middleware(cookieStore))
|
||||
return &testApp{E: e, CookieStore: cookieStore}
|
||||
}
|
||||
|
||||
type mockStore struct {
|
||||
users []model.User
|
||||
savedUsers map[string]model.User
|
||||
deletedUsers []string
|
||||
}
|
||||
|
||||
func (m *mockStore) Init() error { return nil }
|
||||
func (m *mockStore) GetUsers() ([]model.User, error) { return m.users, nil }
|
||||
func (m *mockStore) GetUserByName(username string) (model.User, error) {
|
||||
for _, u := range m.users {
|
||||
if u.Username == username {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
return model.User{}, fmt.Errorf("not found")
|
||||
}
|
||||
func (m *mockStore) GetUserByAuthIdentity(authSource, authSubject string) (model.User, error) {
|
||||
for _, u := range m.users {
|
||||
if u.AuthSource == authSource && u.AuthSubject == authSubject {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
return model.User{}, fmt.Errorf("not found")
|
||||
}
|
||||
func (m *mockStore) SaveUser(user model.User) error {
|
||||
if m.savedUsers == nil {
|
||||
m.savedUsers = make(map[string]model.User)
|
||||
}
|
||||
m.savedUsers[user.Username] = user
|
||||
for i, u := range m.users {
|
||||
if u.Username == user.Username {
|
||||
m.users[i] = user
|
||||
return nil
|
||||
}
|
||||
}
|
||||
m.users = append(m.users, user)
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) ReplaceUser(oldUsername string, user model.User) error {
|
||||
return m.SaveUser(user)
|
||||
}
|
||||
func (m *mockStore) DeleteUser(username string) error {
|
||||
m.deletedUsers = append(m.deletedUsers, username)
|
||||
return nil
|
||||
}
|
||||
func (m *mockStore) GetGlobalSettings() (model.GlobalSetting, error) {
|
||||
return model.GlobalSetting{}, nil
|
||||
}
|
||||
func (m *mockStore) GetServer() (model.Server, error) { return model.Server{}, nil }
|
||||
func (m *mockStore) GetClients(bool) ([]model.ClientData, error) { return nil, nil }
|
||||
func (m *mockStore) GetClientByID(string, model.QRCodeSettings) (model.ClientData, error) {
|
||||
return model.ClientData{}, nil
|
||||
}
|
||||
func (m *mockStore) SaveClient(model.Client) error { return nil }
|
||||
func (m *mockStore) DeleteClient(string) error { return nil }
|
||||
func (m *mockStore) SaveServerInterface(model.ServerInterface) error { return nil }
|
||||
func (m *mockStore) SaveServerKeyPair(model.ServerKeypair) error { return nil }
|
||||
func (m *mockStore) SaveGlobalSettings(model.GlobalSetting) error { return nil }
|
||||
func (m *mockStore) GetWakeOnLanHosts() ([]model.WakeOnLanHost, error) { return nil, nil }
|
||||
func (m *mockStore) GetWakeOnLanHost(string) (*model.WakeOnLanHost, error) { return nil, nil }
|
||||
func (m *mockStore) DeleteWakeOnHostLanHost(string) error { return nil }
|
||||
func (m *mockStore) SaveWakeOnLanHost(model.WakeOnLanHost) error { return nil }
|
||||
func (m *mockStore) DeleteWakeOnHost(model.WakeOnLanHost) error { return nil }
|
||||
func (m *mockStore) GetPath() string { return "" }
|
||||
func (m *mockStore) SaveHashes(model.ClientServerHashes) error { return nil }
|
||||
func (m *mockStore) GetHashes() (model.ClientServerHashes, error) {
|
||||
return model.ClientServerHashes{}, nil
|
||||
}
|
||||
|
||||
func TestGitHubCallback_AllowsWhitelistedUser(t *testing.T) {
|
||||
server, mux := setupTestServer()
|
||||
defer server.Close()
|
||||
setupGitHubAPIHandler(mux, "whitelistuser", 1001, nil, false)
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{"whitelistuser"}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
GitHubUserInfoURL = githubUserURL
|
||||
GitHubOrgsURL = githubOrgsURL
|
||||
|
||||
mockDB := &mockStore{}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("start handler failed: %d", rec.Code)
|
||||
}
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusTemporaryRedirect {
|
||||
t.Fatalf("expected redirect 307, got %d. body: %s, location: %s", rec.Code, rec.Body.String(), rec.Header().Get("Location"))
|
||||
}
|
||||
|
||||
if rec.Header().Get("Location") != "/" {
|
||||
t.Errorf("expected redirect to '/', got '%s'", rec.Header().Get("Location"))
|
||||
}
|
||||
|
||||
sessionCookies := rec.Result().Cookies()
|
||||
var sessionToken string
|
||||
for _, c := range sessionCookies {
|
||||
if c.Name == "session_token" {
|
||||
sessionToken = c.Value
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionToken == "" {
|
||||
t.Fatal("expected session_token cookie to be set")
|
||||
}
|
||||
|
||||
savedUser, ok := mockDB.savedUsers["whitelistuser"]
|
||||
if !ok {
|
||||
t.Fatal("expected user to be saved to store")
|
||||
}
|
||||
if savedUser.AuthSource != "github" {
|
||||
t.Errorf("expected AuthSource 'github', got '%s'", savedUser.AuthSource)
|
||||
}
|
||||
if savedUser.AuthSubject != "1001" {
|
||||
t.Errorf("expected AuthSubject '1001', got '%s'", savedUser.AuthSubject)
|
||||
}
|
||||
if savedUser.Password != "" {
|
||||
t.Errorf("expected empty Password, got '%s'", savedUser.Password)
|
||||
}
|
||||
if savedUser.PasswordHash != "" {
|
||||
t.Errorf("expected empty PasswordHash, got '%s'", savedUser.PasswordHash)
|
||||
}
|
||||
if savedUser.Admin {
|
||||
t.Errorf("expected Admin false, got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_AllowsOrgMember(t *testing.T) {
|
||||
server, mux := setupTestServer()
|
||||
defer server.Close()
|
||||
setupGitHubAPIHandler(mux, "orgmember", 1002, []string{"allowed-org"}, true)
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{}
|
||||
GitHubAllowedOrgs = []string{"allowed-org"}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
GitHubUserInfoURL = githubUserURL
|
||||
GitHubOrgsURL = githubOrgsURL
|
||||
|
||||
mockDB := &mockStore{}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("start handler failed: %d", rec.Code)
|
||||
}
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusTemporaryRedirect {
|
||||
t.Fatalf("expected redirect 307, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_RejectsStateMismatch(t *testing.T) {
|
||||
server, _ := setupTestServer()
|
||||
defer server.Close()
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{"testuser"}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
mockDB := &mockStore{}
|
||||
app := newTestApp()
|
||||
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state=wrong-state", nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_RejectsExpiredState(t *testing.T) {
|
||||
server, _ := setupTestServer()
|
||||
defer server.Close()
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{"testuser"}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
mockDB := &mockStore{}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = time.Now().UTC().Add(-1 * time.Minute).Unix()
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_RejectsExternalNextURL(t *testing.T) {
|
||||
server, _ := setupTestServer()
|
||||
defer server.Close()
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{"testuser"}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
mockDB := &mockStore{}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "http://evil.com/redirect"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_RejectsUnsafeGitHubLogin(t *testing.T) {
|
||||
unsafeLogins := []string{
|
||||
"user with space",
|
||||
"user\nwith\nnewline",
|
||||
"",
|
||||
"-invalid",
|
||||
}
|
||||
|
||||
for _, unsafeLogin := range unsafeLogins {
|
||||
server, mux := setupTestServer()
|
||||
setupGitHubAPIHandler(mux, unsafeLogin, 1001, nil, false)
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{strings.ToLower(unsafeLogin)}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
GitHubUserInfoURL = githubUserURL
|
||||
GitHubOrgsURL = githubOrgsURL
|
||||
|
||||
mockDB := &mockStore{}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
server.Close()
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("login %q: expected status 400, got %d", unsafeLogin, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_MigratesLegacyLocalUser(t *testing.T) {
|
||||
server, mux := setupTestServer()
|
||||
defer server.Close()
|
||||
setupGitHubAPIHandler(mux, "legacyuser", 1001, nil, false)
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{"legacyuser"}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
GitHubUserInfoURL = githubUserURL
|
||||
GitHubOrgsURL = githubOrgsURL
|
||||
|
||||
mockDB := &mockStore{
|
||||
users: []model.User{
|
||||
{
|
||||
Username: "legacyuser",
|
||||
Password: "legacy-pass",
|
||||
PasswordHash: "",
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
Admin: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("start handler failed: %d", rec.Code)
|
||||
}
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusTemporaryRedirect {
|
||||
t.Fatalf("expected redirect 307, got %d", rec.Code)
|
||||
}
|
||||
|
||||
savedUser, ok := mockDB.savedUsers["legacyuser"]
|
||||
if !ok {
|
||||
t.Fatal("expected user to be saved to store")
|
||||
}
|
||||
if savedUser.AuthSource != "github" {
|
||||
t.Errorf("expected AuthSource 'github', got '%s'", savedUser.AuthSource)
|
||||
}
|
||||
if savedUser.AuthSubject != "1001" {
|
||||
t.Errorf("expected AuthSubject '1001', got '%s'", savedUser.AuthSubject)
|
||||
}
|
||||
if savedUser.Password != "" {
|
||||
t.Errorf("expected empty Password, got '%s'", savedUser.Password)
|
||||
}
|
||||
if savedUser.PasswordHash != "" {
|
||||
t.Errorf("expected empty PasswordHash, got '%s'", savedUser.PasswordHash)
|
||||
}
|
||||
if savedUser.DisplayName != "Test User" {
|
||||
t.Errorf("expected DisplayName 'Test User', got '%s'", savedUser.DisplayName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGitHubCallback_RejectsConflictingBoundUser(t *testing.T) {
|
||||
server, mux := setupTestServer()
|
||||
defer server.Close()
|
||||
setupGitHubAPIHandler(mux, "anotheruser", 1002, nil, false)
|
||||
|
||||
GitHubOAuth2Config = &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
ClientSecret: "test-client-secret",
|
||||
RedirectURL: "http://localhost:8080/github/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: server.URL + "/oauth/authorize",
|
||||
TokenURL: githubTokenURL,
|
||||
},
|
||||
}
|
||||
GitHubAllowedUsers = []string{"anotheruser"}
|
||||
GitHubAllowedOrgs = []string{}
|
||||
GitHubAdminUsers = []string{"admin"}
|
||||
|
||||
GitHubUserInfoURL = githubUserURL
|
||||
GitHubOrgsURL = githubOrgsURL
|
||||
|
||||
mockDB := &mockStore{
|
||||
users: []model.User{
|
||||
{
|
||||
Username: "anotheruser",
|
||||
Password: "local-pass",
|
||||
PasswordHash: "",
|
||||
AuthSource: "github",
|
||||
AuthSubject: "999999",
|
||||
Admin: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
app := newTestApp()
|
||||
|
||||
state := generateOAuthState()
|
||||
expiresAt := time.Now().UTC().Add(10 * time.Minute).Unix()
|
||||
|
||||
app.E.POST("/github/start", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
sess.Values["oauth_state"] = state
|
||||
sess.Values["oauth_state_expires_at"] = expiresAt
|
||||
sess.Values["oauth_next"] = "/"
|
||||
sess.Save(c.Request(), c.Response())
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
app.E.GET("/github/callback", GitHubCallback(mockDB))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/github/start", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/github/callback?code=test-code&state="+state, nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", rec.Code)
|
||||
}
|
||||
|
||||
if len(mockDB.savedUsers) > 0 {
|
||||
t.Errorf("expected no user to be saved, got %d saved users", len(mockDB.savedUsers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalLoginAndGitHubLoginShareSessionBuilder(t *testing.T) {
|
||||
testUser := model.User{
|
||||
Username: "testuser",
|
||||
Password: "testpass",
|
||||
PasswordHash: "",
|
||||
AuthSource: "local",
|
||||
AuthSubject: "",
|
||||
Admin: true,
|
||||
}
|
||||
util.DBUsersToCRC32["testuser"] = getTestUserCRC32(testUser)
|
||||
|
||||
mockDB := &mockStore{
|
||||
users: []model.User{testUser},
|
||||
}
|
||||
|
||||
app := newTestAppWithFixedSecret()
|
||||
app.E.POST("/login", Login(mockDB), ContentTypeJson)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(`{"username":"testuser","password":"testpass","rememberMe":true}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d. body: %s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
var sessionToken string
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range cookies {
|
||||
if c.Name == "session_token" {
|
||||
sessionToken = c.Value
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionToken == "" {
|
||||
t.Fatal("expected session_token cookie to be set")
|
||||
}
|
||||
|
||||
if sessionCookie == nil {
|
||||
t.Fatal("expected session_token cookie to be set")
|
||||
}
|
||||
if sessionCookie.Value == "" {
|
||||
t.Fatal("expected session_token cookie value to be non-empty")
|
||||
}
|
||||
if sessionCookie.MaxAge == 0 {
|
||||
t.Fatal("expected session_token cookie MaxAge to be set for rememberMe=true")
|
||||
}
|
||||
|
||||
req = httptest.NewRequest(http.MethodGet, "/profile", nil)
|
||||
for _, c := range cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
rec = httptest.NewRecorder()
|
||||
|
||||
profileCalled := false
|
||||
var profileUsername string
|
||||
var profileAdmin interface{}
|
||||
app.E.GET("/profile", func(c echo.Context) error {
|
||||
sess, _ := session.Get("session", c)
|
||||
profileCalled = true
|
||||
if sess != nil {
|
||||
profileUsername, _ = sess.Values["username"].(string)
|
||||
profileAdmin = sess.Values["admin"]
|
||||
}
|
||||
return c.NoContent(http.StatusOK)
|
||||
}, ValidSession, RefreshSession)
|
||||
app.E.ServeHTTP(rec, req)
|
||||
|
||||
if !profileCalled {
|
||||
t.Fatal("profile handler was not called")
|
||||
}
|
||||
|
||||
if profileUsername != "testuser" {
|
||||
t.Errorf("expected username 'testuser', got '%v'", profileUsername)
|
||||
}
|
||||
|
||||
if profileAdmin != true {
|
||||
t.Errorf("expected admin true, got '%v'", profileAdmin)
|
||||
}
|
||||
|
||||
if sessionCookie.Value != sessionToken {
|
||||
t.Errorf("expected cookie value '%s', got '%s'", sessionToken, sessionCookie.Value)
|
||||
}
|
||||
if sessionCookie.MaxAge != 86400*7 {
|
||||
t.Errorf("expected cookie MaxAge %d, got %d", 86400*7, sessionCookie.MaxAge)
|
||||
}
|
||||
}
|
||||
|
||||
var _ = regexp.MustCompile("")
|
||||
41
main.go
41
main.go
|
|
@ -29,7 +29,7 @@ var (
|
|||
appVersion = "development"
|
||||
gitCommit = "N/A"
|
||||
gitRef = "N/A"
|
||||
buildTime = fmt.Sprintf(time.Now().UTC().Format("01-02-2006 15:04:05"))
|
||||
buildTime = time.Now().UTC().Format("01-02-2006 15:04:05")
|
||||
// configuration variables
|
||||
flagDisableLogin = false
|
||||
flagBindAddress = "0.0.0.0:5000"
|
||||
|
|
@ -137,6 +137,17 @@ func init() {
|
|||
util.BasePath = util.ParseBasePath(flagBasePath)
|
||||
util.SubnetRanges = util.ParseSubnetRanges(flagSubnetRanges)
|
||||
|
||||
util.CurrentAuthMethod = util.ParseAuthMethod()
|
||||
if util.CurrentAuthMethod == util.AuthMethodGitHub {
|
||||
util.CurrentGitHubConfig = util.ParseGitHubAuthConfig()
|
||||
if err := util.ValidateGitHubAuthConfig(util.CurrentGitHubConfig); err != nil {
|
||||
log.Fatalf("GitHub OAuth configuration is invalid: %v", err)
|
||||
}
|
||||
if err := handler.ApplyGitHubAuthConfig(util.CurrentGitHubConfig); err != nil {
|
||||
log.Fatalf("Failed to apply GitHub OAuth configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
lvl, _ := util.ParseLogLevel(util.LookupEnvOrString(util.LogLevel, "INFO"))
|
||||
|
||||
// print only if log level is INFO or lower
|
||||
|
|
@ -174,6 +185,7 @@ func main() {
|
|||
extraData["gitCommit"] = gitCommit
|
||||
extraData["basePath"] = util.BasePath
|
||||
extraData["loginDisabled"] = flagDisableLogin
|
||||
extraData["authMethod"] = util.CurrentAuthMethod
|
||||
|
||||
// strip the "templates/" prefix from the embedded directory so files can be read by their direct name (e.g.
|
||||
// "base.html" instead of "templates/base.html")
|
||||
|
|
@ -204,15 +216,24 @@ func main() {
|
|||
|
||||
if !util.DisableLogin {
|
||||
app.GET(util.BasePath+"/login", handler.LoginPage())
|
||||
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
|
||||
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
|
||||
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
|
||||
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
|
||||
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
|
||||
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
|
||||
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
|
||||
app.GET(util.BasePath+"/get-users", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin)
|
||||
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
|
||||
|
||||
if util.CurrentAuthMethod == util.AuthMethodGitHub {
|
||||
app.GET(util.BasePath+"/auth/github/start", handler.GitHubStart())
|
||||
app.GET(util.BasePath+"/auth/github/callback", handler.GitHubCallback(db))
|
||||
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
|
||||
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
|
||||
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
|
||||
} else {
|
||||
app.POST(util.BasePath+"/login", handler.Login(db), handler.ContentTypeJson)
|
||||
app.GET(util.BasePath+"/logout", handler.Logout(), handler.ValidSession)
|
||||
app.GET(util.BasePath+"/profile", handler.LoadProfile(), handler.ValidSession, handler.RefreshSession)
|
||||
app.GET(util.BasePath+"/users-settings", handler.UsersSettings(), handler.ValidSession, handler.RefreshSession, handler.NeedsAdmin)
|
||||
app.POST(util.BasePath+"/update-user", handler.UpdateUser(db), handler.ValidSession, handler.ContentTypeJson)
|
||||
app.POST(util.BasePath+"/create-user", handler.CreateUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
|
||||
app.POST(util.BasePath+"/remove-user", handler.RemoveUser(db), handler.ValidSession, handler.ContentTypeJson, handler.NeedsAdmin)
|
||||
app.GET(util.BasePath+"/get-users", handler.GetUsers(db), handler.ValidSession, handler.NeedsAdmin)
|
||||
app.GET(util.BasePath+"/api/user/:username", handler.GetUser(db), handler.ValidSession)
|
||||
}
|
||||
}
|
||||
|
||||
var sendmail emailer.Emailer
|
||||
|
|
|
|||
Loading…
Reference in New Issue