Integrate HTPasswdValidator into OAuth2 Proxy
This commit is contained in:
		
							parent
							
								
									7d8ee61254
								
							
						
					
					
						commit
						2981a5ed1a
					
				
							
								
								
									
										72
									
								
								htpasswd.go
								
								
								
								
							
							
						
						
									
										72
									
								
								htpasswd.go
								
								
								
								
							|  | @ -1,72 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"crypto/sha1" |  | ||||||
| 	"encoding/base64" |  | ||||||
| 	"encoding/csv" |  | ||||||
| 	"io" |  | ||||||
| 	"os" |  | ||||||
| 
 |  | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" |  | ||||||
| 	"golang.org/x/crypto/bcrypt" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| // Lookup passwords in a htpasswd file
 |  | ||||||
| // Passwords must be generated with -B for bcrypt or -s for SHA1.
 |  | ||||||
| 
 |  | ||||||
| // HtpasswdFile represents the structure of an htpasswd file
 |  | ||||||
| type HtpasswdFile struct { |  | ||||||
| 	Users map[string]string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NewHtpasswdFromFile constructs an HtpasswdFile from the file at the path given
 |  | ||||||
| func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { |  | ||||||
| 	r, err := os.Open(path) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	defer r.Close() |  | ||||||
| 	return NewHtpasswd(r) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // NewHtpasswd  consctructs an HtpasswdFile from an io.Reader (opened file)
 |  | ||||||
| func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { |  | ||||||
| 	csvReader := csv.NewReader(file) |  | ||||||
| 	csvReader.Comma = ':' |  | ||||||
| 	csvReader.Comment = '#' |  | ||||||
| 	csvReader.TrimLeadingSpace = true |  | ||||||
| 
 |  | ||||||
| 	records, err := csvReader.ReadAll() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 	h := &HtpasswdFile{Users: make(map[string]string)} |  | ||||||
| 	for _, record := range records { |  | ||||||
| 		h.Users[record[0]] = record[1] |  | ||||||
| 	} |  | ||||||
| 	return h, nil |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Validate checks a users password against the HtpasswdFile entries
 |  | ||||||
| func (h *HtpasswdFile) Validate(user string, password string) bool { |  | ||||||
| 	realPassword, exists := h.Users[user] |  | ||||||
| 	if !exists { |  | ||||||
| 		return false |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	shaPrefix := realPassword[:5] |  | ||||||
| 	if shaPrefix == "{SHA}" { |  | ||||||
| 		shaValue := realPassword[5:] |  | ||||||
| 		d := sha1.New() |  | ||||||
| 		d.Write([]byte(password)) |  | ||||||
| 		return shaValue == base64.StdEncoding.EncodeToString(d.Sum(nil)) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	bcryptPrefix := realPassword[:4] |  | ||||||
| 	if bcryptPrefix == "$2a$" || bcryptPrefix == "$2b$" || bcryptPrefix == "$2x$" || bcryptPrefix == "$2y$" { |  | ||||||
| 		return bcrypt.CompareHashAndPassword([]byte(realPassword), []byte(password)) == nil |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	logger.Printf("Invalid htpasswd entry for %s. Must be a SHA or bcrypt entry.", user) |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
|  | @ -1,38 +0,0 @@ | ||||||
| package main |  | ||||||
| 
 |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"fmt" |  | ||||||
| 	"testing" |  | ||||||
| 
 |  | ||||||
| 	"github.com/stretchr/testify/assert" |  | ||||||
| 	"golang.org/x/crypto/bcrypt" |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func TestSHA(t *testing.T) { |  | ||||||
| 	file := bytes.NewBuffer([]byte("testuser:{SHA}PaVBVZkYqAjCQCu6UBL2xgsnZhw=\n")) |  | ||||||
| 	h, err := NewHtpasswd(file) |  | ||||||
| 	assert.Equal(t, err, nil) |  | ||||||
| 
 |  | ||||||
| 	valid := h.Validate("testuser", "asdf") |  | ||||||
| 	assert.Equal(t, valid, true) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestBcrypt(t *testing.T) { |  | ||||||
| 	hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) |  | ||||||
| 	assert.Equal(t, err, nil) |  | ||||||
| 	hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) |  | ||||||
| 	assert.Equal(t, err, nil) |  | ||||||
| 
 |  | ||||||
| 	contents := fmt.Sprintf("testuser1:%s\ntestuser2:%s\n", hash1, hash2) |  | ||||||
| 	file := bytes.NewBuffer([]byte(contents)) |  | ||||||
| 
 |  | ||||||
| 	h, err := NewHtpasswd(file) |  | ||||||
| 	assert.Equal(t, err, nil) |  | ||||||
| 
 |  | ||||||
| 	valid := h.Validate("testuser1", "password") |  | ||||||
| 	assert.Equal(t, valid, true) |  | ||||||
| 
 |  | ||||||
| 	valid = h.Validate("testuser2", "top-secret") |  | ||||||
| 	assert.Equal(t, valid, true) |  | ||||||
| } |  | ||||||
							
								
								
									
										9
									
								
								main.go
								
								
								
								
							
							
						
						
									
										9
									
								
								main.go
								
								
								
								
							|  | @ -66,15 +66,6 @@ func main() { | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if opts.HtpasswdFile != "" { |  | ||||||
| 		logger.Printf("using htpasswd file %s", opts.HtpasswdFile) |  | ||||||
| 		oauthproxy.HtpasswdFile, err = NewHtpasswdFromFile(opts.HtpasswdFile) |  | ||||||
| 		oauthproxy.DisplayHtpasswdForm = opts.DisplayHtpasswdForm |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Fatalf("FATAL: unable to open %s %s", opts.HtpasswdFile, err) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	rand.Seed(time.Now().UnixNano()) | 	rand.Seed(time.Now().UnixNano()) | ||||||
| 
 | 
 | ||||||
| 	chain := alice.New() | 	chain := alice.New() | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ import ( | ||||||
| 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/cookies" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/ip" | 	"github.com/oauth2-proxy/oauth2-proxy/pkg/ip" | ||||||
|  | @ -96,8 +97,8 @@ type OAuthProxy struct { | ||||||
| 	sessionStore            sessionsapi.SessionStore | 	sessionStore            sessionsapi.SessionStore | ||||||
| 	ProxyPrefix             string | 	ProxyPrefix             string | ||||||
| 	SignInMessage           string | 	SignInMessage           string | ||||||
| 	HtpasswdFile            *HtpasswdFile | 	basicAuthValidator      basic.Validator | ||||||
| 	DisplayHtpasswdForm     bool | 	displayHtpasswdForm     bool | ||||||
| 	serveMux                http.Handler | 	serveMux                http.Handler | ||||||
| 	SetXAuthRequest         bool | 	SetXAuthRequest         bool | ||||||
| 	PassBasicAuth           bool | 	PassBasicAuth           bool | ||||||
|  | @ -314,6 +315,16 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	var basicAuthValidator basic.Validator | ||||||
|  | 	if opts.HtpasswdFile != "" { | ||||||
|  | 		logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) | ||||||
|  | 		var err error | ||||||
|  | 		basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("could not load htpasswdfile: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return &OAuthProxy{ | 	return &OAuthProxy{ | ||||||
| 		CookieName:     opts.Cookie.Name, | 		CookieName:     opts.Cookie.Name, | ||||||
| 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), | 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), | ||||||
|  | @ -364,6 +375,9 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | ||||||
| 		trustedIPs:              trustedIPs, | 		trustedIPs:              trustedIPs, | ||||||
| 		Banner:                  opts.Banner, | 		Banner:                  opts.Banner, | ||||||
| 		Footer:                  opts.Footer, | 		Footer:                  opts.Footer, | ||||||
|  | 
 | ||||||
|  | 		basicAuthValidator:  basicAuthValidator, | ||||||
|  | 		displayHtpasswdForm: basicAuthValidator != nil, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -386,10 +400,6 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { | ||||||
| 	return u.String() | 	return u.String() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OAuthProxy) displayCustomLoginForm() bool { |  | ||||||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) { | func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("missing code") | 		return nil, errors.New("missing code") | ||||||
|  | @ -526,7 +536,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 	}{ | 	}{ | ||||||
| 		ProviderName:  p.provider.Data().ProviderName, | 		ProviderName:  p.provider.Data().ProviderName, | ||||||
| 		SignInMessage: template.HTML(p.SignInMessage), | 		SignInMessage: template.HTML(p.SignInMessage), | ||||||
| 		CustomLogin:   p.displayCustomLoginForm(), | 		CustomLogin:   p.displayHtpasswdForm, | ||||||
| 		Redirect:      redirectURL, | 		Redirect:      redirectURL, | ||||||
| 		Version:       VERSION, | 		Version:       VERSION, | ||||||
| 		ProxyPrefix:   p.ProxyPrefix, | 		ProxyPrefix:   p.ProxyPrefix, | ||||||
|  | @ -540,7 +550,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code | ||||||
| 
 | 
 | ||||||
| // ManualSignIn handles basic auth logins to the proxy
 | // ManualSignIn handles basic auth logins to the proxy
 | ||||||
| func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { | func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) { | ||||||
| 	if req.Method != "POST" || p.HtpasswdFile == nil { | 	if req.Method != "POST" || p.basicAuthValidator == nil { | ||||||
| 		return "", false | 		return "", false | ||||||
| 	} | 	} | ||||||
| 	user := req.FormValue("username") | 	user := req.FormValue("username") | ||||||
|  | @ -549,7 +559,7 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st | ||||||
| 		return "", false | 		return "", false | ||||||
| 	} | 	} | ||||||
| 	// check auth
 | 	// check auth
 | ||||||
| 	if p.HtpasswdFile.Validate(user, passwd) { | 	if p.basicAuthValidator.Validate(user, passwd) { | ||||||
| 		logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via HtpasswdFile") | 		logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via HtpasswdFile") | ||||||
| 		return user, true | 		return user, true | ||||||
| 	} | 	} | ||||||
|  | @ -1159,7 +1169,7 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { | ||||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | // credentials and authenticates these against the proxies HtpasswdFile
 | ||||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { | func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	if p.HtpasswdFile == nil { | 	if p.basicAuthValidator == nil { | ||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
| 	auth := req.Header.Get("Authorization") | 	auth := req.Header.Get("Authorization") | ||||||
|  | @ -1178,7 +1188,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionStat | ||||||
| 	if len(pair) != 2 { | 	if len(pair) != 2 { | ||||||
| 		return nil, fmt.Errorf("invalid format %s", b) | 		return nil, fmt.Errorf("invalid format %s", b) | ||||||
| 	} | 	} | ||||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | 	if p.basicAuthValidator.Validate(pair[0], pair[1]) { | ||||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||||
| 		return &sessionsapi.SessionState{User: pair[0]}, nil | 		return &sessionsapi.SessionState{User: pair[0]}, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue