From 2981a5ed1a3884ad7e3167ff8548441010c0681c Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 18 Jul 2020 10:18:47 +0100 Subject: [PATCH] Integrate HTPasswdValidator into OAuth2 Proxy --- htpasswd.go | 72 ------------------------------------------------ htpasswd_test.go | 38 ------------------------- main.go | 9 ------ oauthproxy.go | 32 +++++++++++++-------- 4 files changed, 21 insertions(+), 130 deletions(-) delete mode 100644 htpasswd.go delete mode 100644 htpasswd_test.go diff --git a/htpasswd.go b/htpasswd.go deleted file mode 100644 index 670aa729..00000000 --- a/htpasswd.go +++ /dev/null @@ -1,72 +0,0 @@ -package main - -import ( - "crypto/sha1" - "encoding/base64" - "encoding/csv" - "io" - "os" - - "github.com/oauth2-proxy/oauth2-proxy/pkg/logger" - "golang.org/x/crypto/bcrypt" -) - -// Lookup passwords in a htpasswd file -// Passwords must be generated with -B for bcrypt or -s for SHA1. - -// HtpasswdFile represents the structure of an htpasswd file -type HtpasswdFile struct { - Users map[string]string -} - -// NewHtpasswdFromFile constructs an HtpasswdFile from the file at the path given -func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { - r, err := os.Open(path) - if err != nil { - return nil, err - } - defer r.Close() - return NewHtpasswd(r) -} - -// NewHtpasswd consctructs an HtpasswdFile from an io.Reader (opened file) -func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { - csvReader := csv.NewReader(file) - csvReader.Comma = ':' - csvReader.Comment = '#' - csvReader.TrimLeadingSpace = true - - records, err := csvReader.ReadAll() - if err != nil { - return nil, err - } - h := &HtpasswdFile{Users: make(map[string]string)} - for _, record := range records { - h.Users[record[0]] = record[1] - } - return h, nil -} - -// Validate checks a users password against the HtpasswdFile entries -func (h *HtpasswdFile) Validate(user string, password string) bool { - realPassword, exists := h.Users[user] - if !exists { - return false - } - - shaPrefix := realPassword[:5] - if shaPrefix == "{SHA}" { - shaValue := realPassword[5:] - d := sha1.New() - d.Write([]byte(password)) - return shaValue == base64.StdEncoding.EncodeToString(d.Sum(nil)) - } - - bcryptPrefix := realPassword[:4] - if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" || bcryptPrefix == "$2x$" || bcryptPrefix == "$2y$" { - return bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password)) == nil - } - - logger.Printf("Invalid htpasswd entry for %s. Must be a SHA or bcrypt entry.", user) - return false -} diff --git a/htpasswd_test.go b/htpasswd_test.go deleted file mode 100644 index 7a043e46..00000000 --- a/htpasswd_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package main - -import ( - "bytes" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/bcrypt" -) - -func TestSHA(t *testing.T) { - file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n")) - h, err := NewHtpasswd(file) - assert.Equal(t, err, nil) - - valid := h.Validate("testuser", "asdf") - assert.Equal(t, valid, true) -} - -func TestBcrypt(t *testing.T) { - hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) - assert.Equal(t, err, nil) - hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) - assert.Equal(t, err, nil) - - contents := fmt.Sprintf("testuser1:%s\ntestuser2:%s\n", hash1, hash2) - file := bytes.NewBuffer([]byte(contents)) - - h, err := NewHtpasswd(file) - assert.Equal(t, err, nil) - - valid := h.Validate("testuser1", "password") - assert.Equal(t, valid, true) - - valid = h.Validate("testuser2", "top-secret") - assert.Equal(t, valid, true) -} diff --git a/main.go b/main.go index 4ee6af4b..b42c2c05 100644 --- a/main.go +++ b/main.go @@ -66,15 +66,6 @@ func main() { } } - if opts.HtpasswdFile != "" { - logger.Printf("using htpasswd file %s", opts.HtpasswdFile) - oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(opts.HtpasswdFile) - oauthproxy.DisplayHtpasswdForm = opts.DisplayHtpasswdForm - if err != nil { - logger.Fatalf("FATAL: unable to open %s %s", opts.HtpasswdFile, err) - } - } - rand.Seed(time.Now().UnixNano()) chain := alice.New() diff --git a/oauthproxy.go b/oauthproxy.go index c05c793d..4b310b9b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -22,6 +22,7 @@ import ( ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" + "github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" "github.com/oauth2-proxy/oauth2-proxy/pkg/ip" @@ -96,8 +97,8 @@ type OAuthProxy struct { sessionStore sessionsapi.SessionStore ProxyPrefix string SignInMessage string - HtpasswdFile *HtpasswdFile - DisplayHtpasswdForm bool + basicAuthValidator basic.Validator + displayHtpasswdForm bool serveMux http.Handler SetXAuthRequest bool PassBasicAuth bool @@ -314,6 +315,16 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } } + var basicAuthValidator basic.Validator + if opts.HtpasswdFile != "" { + logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) + var err error + basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) + if err != nil { + return nil, fmt.Errorf("could not load htpasswdfile: %v", err) + } + } + return &OAuthProxy{ CookieName: opts.Cookie.Name, CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), @@ -364,6 +375,9 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr trustedIPs: trustedIPs, Banner: opts.Banner, Footer: opts.Footer, + + basicAuthValidator: basicAuthValidator, + displayHtpasswdForm: basicAuthValidator != nil, }, nil } @@ -386,10 +400,6 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { return u.String() } -func (p *OAuthProxy) displayCustomLoginForm() bool { - return p.HtpasswdFile != nil && p.DisplayHtpasswdForm -} - func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) { if code == "" { return nil, errors.New("missing code") @@ -526,7 +536,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code }{ ProviderName: p.provider.Data().ProviderName, SignInMessage: template.HTML(p.SignInMessage), - CustomLogin: p.displayCustomLoginForm(), + CustomLogin: p.displayHtpasswdForm, Redirect: redirectURL, Version: VERSION, ProxyPrefix: p.ProxyPrefix, @@ -540,7 +550,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code // ManualSignIn handles basic auth logins to the proxy func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { - if req.Method != "POST" || p.HtpasswdFile == nil { + if req.Method != "POST" || p.basicAuthValidator == nil { return "", false } user := req.FormValue("username") @@ -549,7 +559,7 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st return "", false } // check auth - if p.HtpasswdFile.Validate(user, passwd) { + if p.basicAuthValidator.Validate(user, passwd) { logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via HtpasswdFile") return user, true } @@ -1159,7 +1169,7 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { // CheckBasicAuth checks the requests Authorization header for basic auth // credentials and authenticates these against the proxies HtpasswdFile func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { - if p.HtpasswdFile == nil { + if p.basicAuthValidator == nil { return nil, nil } auth := req.Header.Get("Authorization") @@ -1178,7 +1188,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionStat if len(pair) != 2 { return nil, fmt.Errorf("invalid format %s", b) } - if p.HtpasswdFile.Validate(pair[0], pair[1]) { + if p.basicAuthValidator.Validate(pair[0], pair[1]) { logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") return &sessionsapi.SessionState{User: pair[0]}, nil }