Initialise SessionStore in Options
This commit is contained in:
		
							parent
							
								
									17e97ab884
								
							
						
					
					
						commit
						fbee5eae16
					
				|  | @ -16,7 +16,7 @@ import ( | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/yhat/wsutil" | 	"github.com/yhat/wsutil" | ||||||
| ) | ) | ||||||
|  | @ -75,6 +75,7 @@ type OAuthProxy struct { | ||||||
| 	redirectURL         *url.URL // the url to receive requests at
 | 	redirectURL         *url.URL // the url to receive requests at
 | ||||||
| 	whitelistDomains    []string | 	whitelistDomains    []string | ||||||
| 	provider            providers.Provider | 	provider            providers.Provider | ||||||
|  | 	sessionStore        sessionsapi.SessionStore | ||||||
| 	ProxyPrefix         string | 	ProxyPrefix         string | ||||||
| 	SignInMessage       string | 	SignInMessage       string | ||||||
| 	HtpasswdFile        *HtpasswdFile | 	HtpasswdFile        *HtpasswdFile | ||||||
|  | @ -249,6 +250,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | ||||||
| 
 | 
 | ||||||
| 		ProxyPrefix:        opts.ProxyPrefix, | 		ProxyPrefix:        opts.ProxyPrefix, | ||||||
| 		provider:           opts.provider, | 		provider:           opts.provider, | ||||||
|  | 		sessionStore:       opts.sessionStore, | ||||||
| 		serveMux:           serveMux, | 		serveMux:           serveMux, | ||||||
| 		redirectURL:        redirectURL, | 		redirectURL:        redirectURL, | ||||||
| 		whitelistDomains:   opts.WhitelistDomains, | 		whitelistDomains:   opts.WhitelistDomains, | ||||||
|  | @ -293,7 +295,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { | ||||||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { | func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("missing code") | 		return nil, errors.New("missing code") | ||||||
| 	} | 	} | ||||||
|  | @ -485,7 +487,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // LoadCookiedSession reads the user's authentication details from the request
 | // LoadCookiedSession reads the user's authentication details from the request
 | ||||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { | func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, time.Duration, error) { | ||||||
| 	var age time.Duration | 	var age time.Duration | ||||||
| 	c, err := loadCookie(req, p.CookieName) | 	c, err := loadCookie(req, p.CookieName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -507,7 +509,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionSta | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SaveSession creates a new session cookie value and sets this on the response
 | // SaveSession creates a new session cookie value and sets this on the response
 | ||||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|  | @ -694,7 +696,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	user, ok := p.ManualSignIn(rw, req) | 	user, ok := p.ManualSignIn(rw, req) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		session := &sessions.SessionState{User: user} | 		session := &sessionsapi.SessionState{User: user} | ||||||
| 		p.SaveSession(rw, req, session) | 		p.SaveSession(rw, req, session) | ||||||
| 		http.Redirect(rw, req, redirect, 302) | 		http.Redirect(rw, req, redirect, 302) | ||||||
| 	} else { | 	} else { | ||||||
|  | @ -945,7 +947,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 
 | 
 | ||||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | // credentials and authenticates these against the proxies HtpasswdFile
 | ||||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { | func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { | ||||||
| 	if p.HtpasswdFile == nil { | 	if p.HtpasswdFile == nil { | ||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  | @ -967,7 +969,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, | ||||||
| 	} | 	} | ||||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||||
| 		return &sessions.SessionState{User: pair[0]}, nil | 		return &sessionsapi.SessionState{User: pair[0]}, nil | ||||||
| 	} | 	} | ||||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
|  |  | ||||||
							
								
								
									
										16
									
								
								options.go
								
								
								
								
							
							
						
						
									
										16
									
								
								options.go
								
								
								
								
							|  | @ -19,6 +19,8 @@ import ( | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
|  | 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"gopkg.in/natefinch/lumberjack.v2" | 	"gopkg.in/natefinch/lumberjack.v2" | ||||||
| ) | ) | ||||||
|  | @ -111,6 +113,7 @@ type Options struct { | ||||||
| 	proxyURLs     []*url.URL | 	proxyURLs     []*url.URL | ||||||
| 	CompiledRegex []*regexp.Regexp | 	CompiledRegex []*regexp.Regexp | ||||||
| 	provider      providers.Provider | 	provider      providers.Provider | ||||||
|  | 	sessionStore  sessionsapi.SessionStore | ||||||
| 	signatureData *SignatureData | 	signatureData *SignatureData | ||||||
| 	oidcVerifier  *oidc.IDTokenVerifier | 	oidcVerifier  *oidc.IDTokenVerifier | ||||||
| } | } | ||||||
|  | @ -136,6 +139,9 @@ func NewOptions() *Options { | ||||||
| 			CookieExpire:   time.Duration(168) * time.Hour, | 			CookieExpire:   time.Duration(168) * time.Hour, | ||||||
| 			CookieRefresh:  time.Duration(0), | 			CookieRefresh:  time.Duration(0), | ||||||
| 		}, | 		}, | ||||||
|  | 		SessionOptions: options.SessionOptions{ | ||||||
|  | 			Type: "cookie", | ||||||
|  | 		}, | ||||||
| 		SetXAuthRequest:       false, | 		SetXAuthRequest:       false, | ||||||
| 		SkipAuthPreflight:     false, | 		SkipAuthPreflight:     false, | ||||||
| 		PassBasicAuth:         true, | 		PassBasicAuth:         true, | ||||||
|  | @ -283,9 +289,19 @@ func (o *Options) Validate() error { | ||||||
| 					"pass_access_token == true or "+ | 					"pass_access_token == true or "+ | ||||||
| 					"cookie_refresh != 0, but is %d bytes.%s", | 					"cookie_refresh != 0, but is %d bytes.%s", | ||||||
| 				len(secretBytes(o.CookieSecret)), suffix)) | 				len(secretBytes(o.CookieSecret)), suffix)) | ||||||
|  | 		} else { | ||||||
|  | 			// Enable encryption in the session store
 | ||||||
|  | 			o.EnableCipher = true | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions) | ||||||
|  | 	if err != nil { | ||||||
|  | 		msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err)) | ||||||
|  | 	} else { | ||||||
|  | 		o.sessionStore = sessionStore | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if o.CookieRefresh >= o.CookieExpire { | 	if o.CookieRefresh >= o.CookieExpire { | ||||||
| 		msgs = append(msgs, fmt.Sprintf( | 		msgs = append(msgs, fmt.Sprintf( | ||||||
| 			"cookie_refresh (%s) must be less than "+ | 			"cookie_refresh (%s) must be less than "+ | ||||||
|  |  | ||||||
|  | @ -3,6 +3,7 @@ package options | ||||||
| // SessionOptions contains configuration options for the SessionStore providers.
 | // SessionOptions contains configuration options for the SessionStore providers.
 | ||||||
| type SessionOptions struct { | type SessionOptions struct { | ||||||
| 	Type         string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | 	Type         string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | ||||||
|  | 	EnableCipher bool   // Allow the user to choose encryption or not
 | ||||||
| 	CookieStoreOptions | 	CookieStoreOptions | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -11,4 +12,6 @@ type SessionOptions struct { | ||||||
| var CookieSessionStoreType = "cookie" | var CookieSessionStoreType = "cookie" | ||||||
| 
 | 
 | ||||||
| // CookieStoreOptions contains configuration options for the CookieSessionStore.
 | // CookieStoreOptions contains configuration options for the CookieSessionStore.
 | ||||||
| type CookieStoreOptions struct{} | type CookieStoreOptions struct { | ||||||
|  | 	EnableCipher bool // Allow the user to choose encryption or not
 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -126,7 +126,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string, | ||||||
| // the configuration given
 | // the configuration given
 | ||||||
| func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||||
| 	var cipher *cookie.Cipher | 	var cipher *cookie.Cipher | ||||||
| 	if len(cookieOpts.CookieSecret) > 0 { | 	if opts.EnableCipher { | ||||||
| 		var err error | 		var err error | ||||||
| 		cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | 		cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
|  |  | ||||||
|  | @ -12,6 +12,8 @@ import ( | ||||||
| func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||||
| 	switch opts.Type { | 	switch opts.Type { | ||||||
| 	case options.CookieSessionStoreType: | 	case options.CookieSessionStoreType: | ||||||
|  | 		// Ensure EnableCipher is propogated from the parent option
 | ||||||
|  | 		opts.CookieStoreOptions.EnableCipher = opts.EnableCipher | ||||||
| 		return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) | 		return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) | ||||||
| 	default: | 	default: | ||||||
| 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | ||||||
|  |  | ||||||
|  | @ -181,12 +181,13 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 			SessionStoreInterfaceTests() | 			SessionStoreInterfaceTests() | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		Context("with a cookie-secret set", func() { | 		Context("with encryption enabled", func() { | ||||||
| 			BeforeEach(func() { | 			BeforeEach(func() { | ||||||
| 				secret := make([]byte, 32) | 				secret := make([]byte, 32) | ||||||
| 				_, err := rand.Read(secret) | 				_, err := rand.Read(secret) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
| 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | ||||||
|  | 				opts.EnableCipher = true | ||||||
| 
 | 
 | ||||||
| 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||||
| 				Expect(err).ToNot(HaveOccurred()) | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | @ -194,6 +195,19 @@ var _ = Describe("NewSessionStore", func() { | ||||||
| 
 | 
 | ||||||
| 			SessionStoreInterfaceTests() | 			SessionStoreInterfaceTests() | ||||||
| 		}) | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with encryption enabled, but no secret", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				opts.EnableCipher = true | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("returns an error", func() { | ||||||
|  | 				ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||||
|  | 				Expect(err).To(HaveOccurred()) | ||||||
|  | 				Expect(err.Error()).To(Equal("unable to create cipher: crypto/aes: invalid key size 0")) | ||||||
|  | 				Expect(ss).To(BeNil()) | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	BeforeEach(func() { | 	BeforeEach(func() { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue