Merge pull request #147 from pusher/session-store
Add initial session-store interface and implementation
This commit is contained in:
		
						commit
						17e97ab884
					
				|  | @ -10,7 +10,11 @@ | |||
| 
 | ||||
| ## Changes since v3.2.0 | ||||
| 
 | ||||
| - [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings)  | ||||
| - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) | ||||
|   - Allows for multiple different session storage implementations including client and server side | ||||
|   - Adds tests suite for interface to ensure consistency across implementations | ||||
|   - Refactor some configuration options (around cookies) into packages | ||||
| - [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings) | ||||
| - [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath) | ||||
| - [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes) | ||||
| - [#142](https://github.com/pusher/oauth2_proxy/pull/142) ARM Docker USER fix (@kskewes) | ||||
|  |  | |||
|  | @ -57,6 +57,20 @@ | |||
|   pruneopts = "" | ||||
|   revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:b3c5b95e56c06f5aa72cb2500e6ee5f44fcd122872d4fec2023a488e561218bc" | ||||
|   name = "github.com/hpcloud/tail" | ||||
|   packages = [ | ||||
|     ".", | ||||
|     "ratelimiter", | ||||
|     "util", | ||||
|     "watch", | ||||
|     "winfile", | ||||
|   ] | ||||
|   pruneopts = "" | ||||
|   revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5" | ||||
|   version = "v1.0.0" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" | ||||
|   name = "github.com/mbland/hmacauth" | ||||
|  | @ -73,6 +87,54 @@ | |||
|   pruneopts = "" | ||||
|   revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:a3735b0978a8b53fc2ac97a6f46ca1189f0712a00df86d6ec4cf26c1a25e6d77" | ||||
|   name = "github.com/onsi/ginkgo" | ||||
|   packages = [ | ||||
|     ".", | ||||
|     "config", | ||||
|     "internal/codelocation", | ||||
|     "internal/containernode", | ||||
|     "internal/failer", | ||||
|     "internal/leafnodes", | ||||
|     "internal/remote", | ||||
|     "internal/spec", | ||||
|     "internal/spec_iterator", | ||||
|     "internal/specrunner", | ||||
|     "internal/suite", | ||||
|     "internal/testingtproxy", | ||||
|     "internal/writer", | ||||
|     "reporters", | ||||
|     "reporters/stenographer", | ||||
|     "reporters/stenographer/support/go-colorable", | ||||
|     "reporters/stenographer/support/go-isatty", | ||||
|     "types", | ||||
|   ] | ||||
|   pruneopts = "" | ||||
|   revision = "eea6ad008b96acdaa524f5b409513bf062b500ad" | ||||
|   version = "v1.8.0" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:dbafce2fddb1ca331646fe2ac9c9413980368b19a60a4406a6e5861680bd73be" | ||||
|   name = "github.com/onsi/gomega" | ||||
|   packages = [ | ||||
|     ".", | ||||
|     "format", | ||||
|     "internal/assertion", | ||||
|     "internal/asyncassertion", | ||||
|     "internal/oraclematcher", | ||||
|     "internal/testingtsupport", | ||||
|     "matchers", | ||||
|     "matchers/support/goraph/bipartitegraph", | ||||
|     "matchers/support/goraph/edge", | ||||
|     "matchers/support/goraph/node", | ||||
|     "matchers/support/goraph/util", | ||||
|     "types", | ||||
|   ] | ||||
|   pruneopts = "" | ||||
|   revision = "90e289841c1ed79b7a598a7cd9959750cb5e89e2" | ||||
|   version = "v1.5.0" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" | ||||
|   name = "github.com/pmezard/go-difflib" | ||||
|  | @ -131,6 +193,9 @@ | |||
|   packages = [ | ||||
|     "context", | ||||
|     "context/ctxhttp", | ||||
|     "html", | ||||
|     "html/atom", | ||||
|     "html/charset", | ||||
|     "websocket", | ||||
|   ] | ||||
|   pruneopts = "" | ||||
|  | @ -150,6 +215,42 @@ | |||
|   pruneopts = "" | ||||
|   revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" | ||||
| 
 | ||||
| [[projects]] | ||||
|   branch = "master" | ||||
|   digest = "1:67a6e61e60283fd7dce50eba228080bff8805d9d69b2f121d7ec2260d120c4a8" | ||||
|   name = "golang.org/x/sys" | ||||
|   packages = ["unix"] | ||||
|   pruneopts = "" | ||||
|   revision = "ca7f33d4116e3a1f9425755d6a44e7ed9b4c97df" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" | ||||
|   name = "golang.org/x/text" | ||||
|   packages = [ | ||||
|     "encoding", | ||||
|     "encoding/charmap", | ||||
|     "encoding/htmlindex", | ||||
|     "encoding/internal", | ||||
|     "encoding/internal/identifier", | ||||
|     "encoding/japanese", | ||||
|     "encoding/korean", | ||||
|     "encoding/simplifiedchinese", | ||||
|     "encoding/traditionalchinese", | ||||
|     "encoding/unicode", | ||||
|     "internal/gen", | ||||
|     "internal/language", | ||||
|     "internal/language/compact", | ||||
|     "internal/tag", | ||||
|     "internal/utf8internal", | ||||
|     "language", | ||||
|     "runes", | ||||
|     "transform", | ||||
|     "unicode/cldr", | ||||
|   ] | ||||
|   pruneopts = "" | ||||
|   revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" | ||||
|   version = "v0.3.2" | ||||
| 
 | ||||
| [[projects]] | ||||
|   branch = "master" | ||||
|   digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" | ||||
|  | @ -182,6 +283,15 @@ | |||
|   revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" | ||||
|   version = "v1.0.0" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:eb53021a8aa3f599d29c7102e65026242bdedce998a54837dc67f14b6a97c5fd" | ||||
|   name = "gopkg.in/fsnotify.v1" | ||||
|   packages = ["."] | ||||
|   pruneopts = "" | ||||
|   revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" | ||||
|   source = "https://github.com/fsnotify/fsnotify.git" | ||||
|   version = "v1.4.7" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" | ||||
|   name = "gopkg.in/fsnotify/fsnotify.v1" | ||||
|  | @ -210,6 +320,22 @@ | |||
|   revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" | ||||
|   version = "v2.1.3" | ||||
| 
 | ||||
| [[projects]] | ||||
|   branch = "v1" | ||||
|   digest = "1:a96d16bd088460f2e0685d46c39bcf1208ba46e0a977be2df49864ec7da447dd" | ||||
|   name = "gopkg.in/tomb.v1" | ||||
|   packages = ["."] | ||||
|   pruneopts = "" | ||||
|   revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8" | ||||
| 
 | ||||
| [[projects]] | ||||
|   digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" | ||||
|   name = "gopkg.in/yaml.v2" | ||||
|   packages = ["."] | ||||
|   pruneopts = "" | ||||
|   revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" | ||||
|   version = "v2.2.2" | ||||
| 
 | ||||
| [solve-meta] | ||||
|   analyzer-name = "dep" | ||||
|   analyzer-version = 1 | ||||
|  | @ -220,6 +346,8 @@ | |||
|     "github.com/dgrijalva/jwt-go", | ||||
|     "github.com/mbland/hmacauth", | ||||
|     "github.com/mreiferson/go-options", | ||||
|     "github.com/onsi/ginkgo", | ||||
|     "github.com/onsi/gomega", | ||||
|     "github.com/stretchr/testify/assert", | ||||
|     "github.com/stretchr/testify/require", | ||||
|     "github.com/yhat/wsutil", | ||||
|  | @ -231,6 +359,7 @@ | |||
|     "google.golang.org/api/googleapi", | ||||
|     "gopkg.in/fsnotify/fsnotify.v1", | ||||
|     "gopkg.in/natefinch/lumberjack.v2", | ||||
|     "gopkg.in/square/go-jose.v2", | ||||
|   ] | ||||
|   solver-name = "gps-cdcl" | ||||
|   solver-version = 1 | ||||
|  |  | |||
|  | @ -35,6 +35,10 @@ | |||
|   name = "gopkg.in/fsnotify/fsnotify.v1" | ||||
|   version = "~1.2.0" | ||||
| 
 | ||||
| [[override]] | ||||
|   name = "gopkg.in/fsnotify.v1" | ||||
|   source = "https://github.com/fsnotify/fsnotify.git" | ||||
| 
 | ||||
| [[constraint]] | ||||
|   branch = "master" | ||||
|   name = "golang.org/x/crypto" | ||||
|  |  | |||
							
								
								
									
										1
									
								
								Makefile
								
								
								
								
							
							
						
						
									
										1
									
								
								Makefile
								
								
								
								
							|  | @ -33,6 +33,7 @@ lint: $(GOMETALINTER) | |||
| 		--enable=deadcode \
 | ||||
| 		--enable=gofmt \
 | ||||
| 		--enable=goimports \
 | ||||
| 		--deadline=120s \
 | ||||
| 		--tests ./... | ||||
| 
 | ||||
| .PHONY: dep | ||||
|  |  | |||
|  | @ -1,7 +1,8 @@ | |||
| --- | ||||
| layout: default | ||||
| title: Configuration | ||||
| permalink: /configuration | ||||
| permalink: /docs/configuration | ||||
| has_children: true | ||||
| nav_order: 3 | ||||
| --- | ||||
| 
 | ||||
|  | @ -78,6 +79,7 @@ Usage of oauth2_proxy: | |||
|   -request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below) | ||||
|   -resource string: The resource that is protected (Azure AD only) | ||||
|   -scope string: OAuth scope specification | ||||
|   -session-store-type: Session data storage backend (default: cookie) | ||||
|   -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) | ||||
|   -set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode) | ||||
|   -signature-key string: GAP-Signature request signature key (algorithm:secretkey) | ||||
|  | @ -0,0 +1,34 @@ | |||
| --- | ||||
| layout: default | ||||
| title: Sessions | ||||
| permalink: /configuration | ||||
| parent: Configuration | ||||
| nav_order: 3 | ||||
| --- | ||||
| 
 | ||||
| ## Sessions | ||||
| 
 | ||||
| Sessions allow a user's authentication to be tracked between multiple HTTP | ||||
| requests to a service. | ||||
| 
 | ||||
| The OAuth2 Proxy uses a Cookie to track user sessions and will store the session | ||||
| data in one of the available session storage backends. | ||||
| 
 | ||||
| At present the available backends are (as passed to `--session-store-type`): | ||||
| - [cookie](cookie-storage) (deafult) | ||||
| 
 | ||||
| ### Cookie Storage | ||||
| 
 | ||||
| The Cookie storage backend is the default backend implementation and has | ||||
| been used in the OAuth2 Proxy historically. | ||||
| 
 | ||||
| With the Cookie storage backend, all session information is stored in client | ||||
| side cookies and transferred with each and every request. | ||||
| 
 | ||||
| The following should be known when using this implementation: | ||||
| - Since all state is stored client side, this storage backend means that the OAuth2 Proxy is completely stateless | ||||
| - Cookies are signed server side to prevent modification client-side | ||||
| - It is recommended to set a `cookie-secret` which will ensure data is encrypted within the cookie data. | ||||
| - Since multiple requests can be made concurrently to the OAuth2 Proxy, this session implementation | ||||
| cannot lock sessions and while updating and refreshing sessions, there can be conflicts which force | ||||
| users to re-authenticate | ||||
|  | @ -15,14 +15,27 @@ type EnvOptions map[string]interface{} | |||
| // Fields in the options struct must have an `env` and `cfg` tag to be read
 | ||||
| // from the environment
 | ||||
| func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { | ||||
| 	val := reflect.ValueOf(options).Elem() | ||||
| 	typ := val.Type() | ||||
| 	val := reflect.ValueOf(options) | ||||
| 	var typ reflect.Type | ||||
| 	if val.Kind() == reflect.Ptr { | ||||
| 		typ = val.Elem().Type() | ||||
| 	} else { | ||||
| 		typ = val.Type() | ||||
| 	} | ||||
| 
 | ||||
| 	for i := 0; i < typ.NumField(); i++ { | ||||
| 		// pull out the struct tags:
 | ||||
| 		//    flag - the name of the command line flag
 | ||||
| 		//    deprecated - (optional) the name of the deprecated command line flag
 | ||||
| 		//    cfg - (optional, defaults to underscored flag) the name of the config file option
 | ||||
| 		field := typ.Field(i) | ||||
| 		fieldV := reflect.Indirect(val).Field(i) | ||||
| 
 | ||||
| 		if field.Type.Kind() == reflect.Struct && field.Anonymous { | ||||
| 			cfg.LoadEnvForStruct(fieldV.Interface()) | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		flagName := field.Tag.Get("flag") | ||||
| 		envName := field.Tag.Get("env") | ||||
| 		cfgName := field.Tag.Get("cfg") | ||||
|  |  | |||
|  | @ -1,26 +1,46 @@ | |||
| package main | ||||
| package main_test | ||||
| 
 | ||||
| import ( | ||||
| 	"os" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	proxy "github.com/pusher/oauth2_proxy" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| type envTest struct { | ||||
| 	testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` | ||||
| type EnvTest struct { | ||||
| 	TestField string `cfg:"target_field" env:"TEST_ENV_FIELD"` | ||||
| 	EnvTestEmbed | ||||
| } | ||||
| 
 | ||||
| type EnvTestEmbed struct { | ||||
| 	TestFieldEmbed string `cfg:"target_field_embed" env:"TEST_ENV_FIELD_EMBED"` | ||||
| } | ||||
| 
 | ||||
| func TestLoadEnvForStruct(t *testing.T) { | ||||
| 
 | ||||
| 	cfg := make(EnvOptions) | ||||
| 	cfg.LoadEnvForStruct(&envTest{}) | ||||
| 	cfg := make(proxy.EnvOptions) | ||||
| 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||
| 
 | ||||
| 	_, ok := cfg["target_field"] | ||||
| 	assert.Equal(t, ok, false) | ||||
| 
 | ||||
| 	os.Setenv("TEST_ENV_FIELD", "1234abcd") | ||||
| 	cfg.LoadEnvForStruct(&envTest{}) | ||||
| 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||
| 	v := cfg["target_field"] | ||||
| 	assert.Equal(t, v, "1234abcd") | ||||
| } | ||||
| 
 | ||||
| func TestLoadEnvForStructWithEmbeddedFields(t *testing.T) { | ||||
| 
 | ||||
| 	cfg := make(proxy.EnvOptions) | ||||
| 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||
| 
 | ||||
| 	_, ok := cfg["target_field_embed"] | ||||
| 	assert.Equal(t, ok, false) | ||||
| 
 | ||||
| 	os.Setenv("TEST_ENV_FIELD_EMBED", "1234abcd") | ||||
| 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||
| 	v := cfg["target_field_embed"] | ||||
| 	assert.Equal(t, v, "1234abcd") | ||||
| } | ||||
|  |  | |||
							
								
								
									
										2
									
								
								main.go
								
								
								
								
							
							
						
						
									
										2
									
								
								main.go
								
								
								
								
							|  | @ -75,6 +75,8 @@ func main() { | |||
| 	flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") | ||||
| 	flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") | ||||
| 
 | ||||
| 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | ||||
| 
 | ||||
| 	flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") | ||||
| 	flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation") | ||||
| 	flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files") | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ import ( | |||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/providers" | ||||
| 	"github.com/yhat/wsutil" | ||||
| ) | ||||
|  | @ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { | |||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | ||||
| } | ||||
| 
 | ||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { | ||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		return nil, errors.New("missing code") | ||||
| 	} | ||||
|  | @ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, | |||
| } | ||||
| 
 | ||||
| // LoadCookiedSession reads the user's authentication details from the request
 | ||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { | ||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { | ||||
| 	var age time.Duration | ||||
| 	c, err := loadCookie(req, p.CookieName) | ||||
| 	if err != nil { | ||||
|  | @ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt | |||
| } | ||||
| 
 | ||||
| // SaveSession creates a new session cookie value and sets this on the response
 | ||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { | ||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
|  | @ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | |||
| 
 | ||||
| 	user, ok := p.ManualSignIn(rw, req) | ||||
| 	if ok { | ||||
| 		session := &providers.SessionState{User: user} | ||||
| 		session := &sessions.SessionState{User: user} | ||||
| 		p.SaveSession(rw, req, session) | ||||
| 		http.Redirect(rw, req, redirect, 302) | ||||
| 	} else { | ||||
|  | @ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | |||
| 
 | ||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | ||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { | ||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { | ||||
| 	if p.HtpasswdFile == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | @ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, | |||
| 	} | ||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||
| 		return &providers.SessionState{User: pair[0]}, nil | ||||
| 		return &sessions.SessionState{User: pair[0]}, nil | ||||
| 	} | ||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||
| 	return nil, nil | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ import ( | |||
| 
 | ||||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/providers" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/require" | ||||
|  | @ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { | ||||
| func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { | ||||
| 	return tp.EmailAddress, nil | ||||
| } | ||||
| 
 | ||||
| func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { | ||||
| func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { | ||||
| 	return tp.ValidToken | ||||
| } | ||||
| 
 | ||||
|  | @ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook | |||
| 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { | ||||
| func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { | ||||
| 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
|  | @ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time | |||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { | ||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { | ||||
| 	return p.proxy.LoadCookiedSession(p.req) | ||||
| } | ||||
| 
 | ||||
| func TestLoadCookiedSession(t *testing.T) { | ||||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 
 | ||||
| 	startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | ||||
| 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, time.Now()) | ||||
| 
 | ||||
| 	session, _, err := pcTest.LoadCookiedSession() | ||||
|  | @ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { | |||
| 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||
| 
 | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, reference) | ||||
| 
 | ||||
| 	session, age, err := pcTest.LoadCookiedSession() | ||||
|  | @ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { | |||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, reference) | ||||
| 
 | ||||
| 	session, _, err := pcTest.LoadCookiedSession() | ||||
|  | @ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | |||
| 	pcTest := NewProcessCookieTestWithDefaults() | ||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	pcTest.SaveSession(startSession, reference) | ||||
| 
 | ||||
| 	pcTest.proxy.CookieRefresh = time.Hour | ||||
|  | @ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { | |||
| 
 | ||||
| func TestAuthOnlyEndpointAccepted(t *testing.T) { | ||||
| 	test := NewAuthOnlyEndpointTest() | ||||
| 	startSession := &providers.SessionState{ | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	test.SaveSession(startSession, time.Now()) | ||||
| 
 | ||||
|  | @ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | |||
| 	test := NewAuthOnlyEndpointTest() | ||||
| 	test.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||
| 	startSession := &providers.SessionState{ | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	test.SaveSession(startSession, reference) | ||||
| 
 | ||||
|  | @ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | |||
| 
 | ||||
| func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||
| 	test := NewAuthOnlyEndpointTest() | ||||
| 	startSession := &providers.SessionState{ | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||
| 	test.SaveSession(startSession, time.Now()) | ||||
| 	test.validateUser = false | ||||
|  | @ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | |||
| 	pcTest.req, _ = http.NewRequest("GET", | ||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||
| 
 | ||||
| 	startSession := &providers.SessionState{ | ||||
| 	startSession := &sessions.SessionState{ | ||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | ||||
| 	pcTest.SaveSession(startSession, time.Now()) | ||||
| 
 | ||||
|  | @ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { | |||
| 	req := httptest.NewRequest(method, "/foo/bar", bodyBuf) | ||||
| 	req.Header = st.header | ||||
| 
 | ||||
| 	state := &providers.SessionState{ | ||||
| 	state := &sessions.SessionState{ | ||||
| 		Email: "mbland@acm.org", AccessToken: "my_access_token"} | ||||
| 	value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) | ||||
| 	if err != nil { | ||||
|  |  | |||
							
								
								
									
										36
									
								
								options.go
								
								
								
								
							
							
						
						
									
										36
									
								
								options.go
								
								
								
								
							|  | @ -18,6 +18,7 @@ import ( | |||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/mbland/hmacauth" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| 	"github.com/pusher/oauth2_proxy/providers" | ||||
| 	"gopkg.in/natefinch/lumberjack.v2" | ||||
| ) | ||||
|  | @ -49,14 +50,11 @@ type Options struct { | |||
| 	CustomTemplatesDir       string   `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"` | ||||
| 	Footer                   string   `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"` | ||||
| 
 | ||||
| 	CookieName     string        `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` | ||||
| 	CookieSecret   string        `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` | ||||
| 	CookieDomain   string        `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` | ||||
| 	CookiePath     string        `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` | ||||
| 	CookieExpire   time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` | ||||
| 	CookieRefresh  time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` | ||||
| 	CookieSecure   bool          `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` | ||||
| 	CookieHTTPOnly bool          `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` | ||||
| 	// Embed CookieOptions
 | ||||
| 	options.CookieOptions | ||||
| 
 | ||||
| 	// Embed SessionOptions
 | ||||
| 	options.SessionOptions | ||||
| 
 | ||||
| 	Upstreams             []string      `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"` | ||||
| 	SkipAuthRegex         []string      `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"` | ||||
|  | @ -126,16 +124,18 @@ type SignatureData struct { | |||
| // NewOptions constructs a new Options with defaulted values
 | ||||
| func NewOptions() *Options { | ||||
| 	return &Options{ | ||||
| 		ProxyPrefix:           "/oauth2", | ||||
| 		ProxyWebSockets:       true, | ||||
| 		HTTPAddress:           "127.0.0.1:4180", | ||||
| 		HTTPSAddress:          ":443", | ||||
| 		DisplayHtpasswdForm:   true, | ||||
| 		CookieName:            "_oauth2_proxy", | ||||
| 		CookieSecure:          true, | ||||
| 		CookieHTTPOnly:        true, | ||||
| 		CookieExpire:          time.Duration(168) * time.Hour, | ||||
| 		CookieRefresh:         time.Duration(0), | ||||
| 		ProxyPrefix:         "/oauth2", | ||||
| 		ProxyWebSockets:     true, | ||||
| 		HTTPAddress:         "127.0.0.1:4180", | ||||
| 		HTTPSAddress:        ":443", | ||||
| 		DisplayHtpasswdForm: true, | ||||
| 		CookieOptions: options.CookieOptions{ | ||||
| 			CookieName:     "_oauth2_proxy", | ||||
| 			CookieSecure:   true, | ||||
| 			CookieHTTPOnly: true, | ||||
| 			CookieExpire:   time.Duration(168) * time.Hour, | ||||
| 			CookieRefresh:  time.Duration(0), | ||||
| 		}, | ||||
| 		SetXAuthRequest:       false, | ||||
| 		SkipAuthPreflight:     false, | ||||
| 		PassBasicAuth:         true, | ||||
|  |  | |||
|  | @ -0,0 +1,15 @@ | |||
| package options | ||||
| 
 | ||||
| import "time" | ||||
| 
 | ||||
| // CookieOptions contains configuration options relating to Cookie configuration
 | ||||
| type CookieOptions struct { | ||||
| 	CookieName     string        `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` | ||||
| 	CookieSecret   string        `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` | ||||
| 	CookieDomain   string        `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` | ||||
| 	CookiePath     string        `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` | ||||
| 	CookieExpire   time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` | ||||
| 	CookieRefresh  time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` | ||||
| 	CookieSecure   bool          `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` | ||||
| 	CookieHTTPOnly bool          `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` | ||||
| } | ||||
|  | @ -0,0 +1,14 @@ | |||
| package options | ||||
| 
 | ||||
| // SessionOptions contains configuration options for the SessionStore providers.
 | ||||
| type SessionOptions struct { | ||||
| 	Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | ||||
| 	CookieStoreOptions | ||||
| } | ||||
| 
 | ||||
| // CookieSessionStoreType is used to indicate the CookieSessionStore should be
 | ||||
| // used for storing sessions.
 | ||||
| var CookieSessionStoreType = "cookie" | ||||
| 
 | ||||
| // CookieStoreOptions contains configuration options for the CookieSessionStore.
 | ||||
| type CookieStoreOptions struct{} | ||||
|  | @ -0,0 +1,12 @@ | |||
| package sessions | ||||
| 
 | ||||
| import ( | ||||
| 	"net/http" | ||||
| ) | ||||
| 
 | ||||
| // SessionStore is an interface to storing user sessions in the proxy
 | ||||
| type SessionStore interface { | ||||
| 	Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error | ||||
| 	Load(req *http.Request) (*SessionState, error) | ||||
| 	Clear(rw http.ResponseWriter, req *http.Request) error | ||||
| } | ||||
|  | @ -1,4 +1,4 @@ | |||
| package providers | ||||
| package sessions | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/json" | ||||
|  | @ -1,4 +1,4 @@ | |||
| package providers | ||||
| package sessions_test | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := cookie.NewCipher([]byte(altSecret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &SessionState{ | ||||
| 	s := &sessions.SessionState{ | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		IDToken:      "rawtoken1234", | ||||
|  | @ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@domain.com", ss.User) | ||||
|  | @ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
| 
 | ||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||
| 	ss, err = DecodeSessionState(encoded, c2) | ||||
| 	ss, err = sessions.DecodeSessionState(encoded, c2) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, "user@domain.com", ss.User) | ||||
|  | @ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 	c2, err := cookie.NewCipher([]byte(altSecret)) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	s := &SessionState{ | ||||
| 	s := &sessions.SessionState{ | ||||
| 		User:         "just-user", | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
|  | @ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.User, ss.User) | ||||
|  | @ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
| 
 | ||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||
| 	ss, err = DecodeSessionState(encoded, c2) | ||||
| 	ss, err = sessions.DecodeSessionState(encoded, c2) | ||||
| 	t.Logf("%#v", ss) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.NotEqual(t, s.User, ss.User) | ||||
|  | @ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||
| 	s := &SessionState{ | ||||
| 	s := &sessions.SessionState{ | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
|  | @ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	// only email should have been serialized
 | ||||
| 	ss, err := DecodeSessionState(encoded, nil) | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@domain.com", ss.User) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
|  | @ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||
| 	s := &SessionState{ | ||||
| 	s := &sessions.SessionState{ | ||||
| 		User:         "just-user", | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
|  | @ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | |||
| 	assert.Equal(t, nil, err) | ||||
| 
 | ||||
| 	// only email should have been serialized
 | ||||
| 	ss, err := DecodeSessionState(encoded, nil) | ||||
| 	ss, err := sessions.DecodeSessionState(encoded, nil) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, s.User, ss.User) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
|  | @ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | |||
| } | ||||
| 
 | ||||
| func TestExpired(t *testing.T) { | ||||
| 	s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} | ||||
| 	s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} | ||||
| 	assert.Equal(t, true, s.IsExpired()) | ||||
| 
 | ||||
| 	s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} | ||||
| 	s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} | ||||
| 	assert.Equal(t, false, s.IsExpired()) | ||||
| 
 | ||||
| 	s = &SessionState{} | ||||
| 	s = &sessions.SessionState{} | ||||
| 	assert.Equal(t, false, s.IsExpired()) | ||||
| } | ||||
| 
 | ||||
| type testCase struct { | ||||
| 	SessionState | ||||
| 	sessions.SessionState | ||||
| 	Encoded string | ||||
| 	Cipher  *cookie.Cipher | ||||
| 	Error   bool | ||||
|  | @ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) { | |||
| 
 | ||||
| 	testCases := []testCase{ | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email:        "user@domain.com", | ||||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
|  | @ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) { | |||
| 
 | ||||
| 	for i, tc := range testCases { | ||||
| 		encoded, err := tc.EncodeSessionState(tc.Cipher) | ||||
| 		t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | ||||
| 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | ||||
| 		if tc.Error { | ||||
| 			assert.Error(t, err) | ||||
| 			assert.Empty(t, encoded) | ||||
|  | @ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // TestDecodeSessionState tests DecodeSessionState with the test vector
 | ||||
| // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
 | ||||
| func TestDecodeSessionState(t *testing.T) { | ||||
| 	e := time.Now().Add(time.Duration(1) * time.Hour) | ||||
| 	eJSON, _ := e.MarshalJSON() | ||||
|  | @ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 
 | ||||
| 	testCases := []testCase{ | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "user@domain.com", | ||||
| 			}, | ||||
| 			Encoded: `{"Email":"user@domain.com"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				User: "just-user", | ||||
| 			}, | ||||
| 			Encoded: `{"User":"just-user"}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email:        "user@domain.com", | ||||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
|  | @ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 			Cipher:  c, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email: "user@domain.com", | ||||
| 				User:  "just-user", | ||||
| 			}, | ||||
|  | @ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 			Error:   true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				User:  "just-user", | ||||
| 				Email: "user@domain.com", | ||||
| 			}, | ||||
|  | @ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 			Error:   true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email:        "user@domain.com", | ||||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
|  | @ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 			Cipher:  c, | ||||
| 		}, | ||||
| 		{ | ||||
| 			SessionState: SessionState{ | ||||
| 			SessionState: sessions.SessionState{ | ||||
| 				Email:        "user@domain.com", | ||||
| 				User:         "just-user", | ||||
| 				AccessToken:  "token1234", | ||||
|  | @ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) { | |||
| 	} | ||||
| 
 | ||||
| 	for i, tc := range testCases { | ||||
| 		ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) | ||||
| 		t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | ||||
| 		ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) | ||||
| 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | ||||
| 		if tc.Error { | ||||
| 			assert.Error(t, err) | ||||
| 			assert.Nil(t, ss) | ||||
|  | @ -0,0 +1,34 @@ | |||
| package cookies | ||||
| 
 | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| ) | ||||
| 
 | ||||
| // MakeCookie constructs a cookie from the given parameters,
 | ||||
| // discovering the domain from the request if not specified.
 | ||||
| func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time) *http.Cookie { | ||||
| 	if domain != "" { | ||||
| 		host := req.Host | ||||
| 		if h, _, err := net.SplitHostPort(host); err == nil { | ||||
| 			host = h | ||||
| 		} | ||||
| 		if !strings.HasSuffix(host, domain) { | ||||
| 			logger.Printf("Warning: request host is %q but using configured cookie domain of %q", host, domain) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return &http.Cookie{ | ||||
| 		Name:     name, | ||||
| 		Value:    value, | ||||
| 		Path:     path, | ||||
| 		Domain:   domain, | ||||
| 		HttpOnly: httpOnly, | ||||
| 		Secure:   secure, | ||||
| 		Expires:  now.Add(expiration), | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,232 @@ | |||
| package cookie | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/cookies" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/utils" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// Cookies are limited to 4kb including the length of the cookie name,
 | ||||
| 	// the cookie name can be up to 256 bytes
 | ||||
| 	maxCookieLength = 3840 | ||||
| ) | ||||
| 
 | ||||
| // Ensure CookieSessionStore implements the interface
 | ||||
| var _ sessions.SessionStore = &SessionStore{} | ||||
| 
 | ||||
| // SessionStore is an implementation of the sessions.SessionStore
 | ||||
| // interface that stores sessions in client side cookies
 | ||||
| type SessionStore struct { | ||||
| 	CookieCipher   *cookie.Cipher | ||||
| 	CookieDomain   string | ||||
| 	CookieExpire   time.Duration | ||||
| 	CookieHTTPOnly bool | ||||
| 	CookieName     string | ||||
| 	CookiePath     string | ||||
| 	CookieSecret   string | ||||
| 	CookieSecure   bool | ||||
| } | ||||
| 
 | ||||
| // Save takes a sessions.SessionState and stores the information from it
 | ||||
| // within Cookies set on the HTTP response writer
 | ||||
| func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||
| 	value, err := utils.CookieForSession(ss, s.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	s.setSessionCookie(rw, req, value) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // Load reads sessions.SessionState information from Cookies within the
 | ||||
| // HTTP request object
 | ||||
| func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { | ||||
| 	c, err := loadCookie(req, s.CookieName) | ||||
| 	if err != nil { | ||||
| 		// always http.ErrNoCookie
 | ||||
| 		return nil, fmt.Errorf("Cookie %q not present", s.CookieName) | ||||
| 	} | ||||
| 	val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) | ||||
| 	if !ok { | ||||
| 		return nil, errors.New("Cookie Signature not valid") | ||||
| 	} | ||||
| 
 | ||||
| 	session, err := utils.SessionFromCookie(val, s.CookieCipher) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // Clear clears any saved session information by writing a cookie to
 | ||||
| // clear the session
 | ||||
| func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	var cookies []*http.Cookie | ||||
| 
 | ||||
| 	// matches CookieName, CookieName_<number>
 | ||||
| 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName)) | ||||
| 
 | ||||
| 	for _, c := range req.Cookies() { | ||||
| 		if cookieNameRegex.MatchString(c.Name) { | ||||
| 			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) | ||||
| 
 | ||||
| 			http.SetCookie(rw, clearCookie) | ||||
| 			cookies = append(cookies, clearCookie) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // setSessionCookie adds the user's session cookie to the response
 | ||||
| func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||
| 	for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { | ||||
| 		http.SetCookie(rw, c) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // makeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||
| // authentication details
 | ||||
| func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { | ||||
| 	if value != "" { | ||||
| 		value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) | ||||
| 	} | ||||
| 	c := s.makeCookie(req, s.CookieName, value, expiration) | ||||
| 	if len(c.Value) > 4096-len(s.CookieName) { | ||||
| 		return splitCookie(c) | ||||
| 	} | ||||
| 	return []*http.Cookie{c} | ||||
| } | ||||
| 
 | ||||
| func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { | ||||
| 	return cookies.MakeCookie( | ||||
| 		req, | ||||
| 		name, | ||||
| 		value, | ||||
| 		s.CookiePath, | ||||
| 		s.CookieDomain, | ||||
| 		s.CookieHTTPOnly, | ||||
| 		s.CookieSecure, | ||||
| 		expiration, | ||||
| 		time.Now(), | ||||
| 	) | ||||
| } | ||||
| 
 | ||||
| // NewCookieSessionStore initialises a new instance of the SessionStore from
 | ||||
| // the configuration given
 | ||||
| func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||
| 	var cipher *cookie.Cipher | ||||
| 	if len(cookieOpts.CookieSecret) > 0 { | ||||
| 		var err error | ||||
| 		cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | ||||
| 		if err != nil { | ||||
| 			return nil, fmt.Errorf("unable to create cipher: %v", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return &SessionStore{ | ||||
| 		CookieCipher:   cipher, | ||||
| 		CookieDomain:   cookieOpts.CookieDomain, | ||||
| 		CookieExpire:   cookieOpts.CookieExpire, | ||||
| 		CookieHTTPOnly: cookieOpts.CookieHTTPOnly, | ||||
| 		CookieName:     cookieOpts.CookieName, | ||||
| 		CookiePath:     cookieOpts.CookiePath, | ||||
| 		CookieSecret:   cookieOpts.CookieSecret, | ||||
| 		CookieSecure:   cookieOpts.CookieSecure, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // splitCookie reads the full cookie generated to store the session and splits
 | ||||
| // it into a slice of cookies which fit within the 4kb cookie limit indexing
 | ||||
| // the cookies from 0
 | ||||
| func splitCookie(c *http.Cookie) []*http.Cookie { | ||||
| 	if len(c.Value) < maxCookieLength { | ||||
| 		return []*http.Cookie{c} | ||||
| 	} | ||||
| 	cookies := []*http.Cookie{} | ||||
| 	valueBytes := []byte(c.Value) | ||||
| 	count := 0 | ||||
| 	for len(valueBytes) > 0 { | ||||
| 		new := copyCookie(c) | ||||
| 		new.Name = fmt.Sprintf("%s_%d", c.Name, count) | ||||
| 		count++ | ||||
| 		if len(valueBytes) < maxCookieLength { | ||||
| 			new.Value = string(valueBytes) | ||||
| 			valueBytes = []byte{} | ||||
| 		} else { | ||||
| 			newValue := valueBytes[:maxCookieLength] | ||||
| 			valueBytes = valueBytes[maxCookieLength:] | ||||
| 			new.Value = string(newValue) | ||||
| 		} | ||||
| 		cookies = append(cookies, new) | ||||
| 	} | ||||
| 	return cookies | ||||
| } | ||||
| 
 | ||||
| // loadCookie retreieves the sessions state cookie from the http request.
 | ||||
| // If a single cookie is present this will be returned, otherwise it attempts
 | ||||
| // to reconstruct a cookie split up by splitCookie
 | ||||
| func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { | ||||
| 	c, err := req.Cookie(cookieName) | ||||
| 	if err == nil { | ||||
| 		return c, nil | ||||
| 	} | ||||
| 	cookies := []*http.Cookie{} | ||||
| 	err = nil | ||||
| 	count := 0 | ||||
| 	for err == nil { | ||||
| 		var c *http.Cookie | ||||
| 		c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) | ||||
| 		if err == nil { | ||||
| 			cookies = append(cookies, c) | ||||
| 			count++ | ||||
| 		} | ||||
| 	} | ||||
| 	if len(cookies) == 0 { | ||||
| 		return nil, fmt.Errorf("Could not find cookie %s", cookieName) | ||||
| 	} | ||||
| 	return joinCookies(cookies) | ||||
| } | ||||
| 
 | ||||
| // joinCookies takes a slice of cookies from the request and reconstructs the
 | ||||
| // full session cookie
 | ||||
| func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { | ||||
| 	if len(cookies) == 0 { | ||||
| 		return nil, fmt.Errorf("list of cookies must be > 0") | ||||
| 	} | ||||
| 	if len(cookies) == 1 { | ||||
| 		return cookies[0], nil | ||||
| 	} | ||||
| 	c := copyCookie(cookies[0]) | ||||
| 	for i := 1; i < len(cookies); i++ { | ||||
| 		c.Value += cookies[i].Value | ||||
| 	} | ||||
| 	c.Name = strings.TrimRight(c.Name, "_0") | ||||
| 	return c, nil | ||||
| } | ||||
| 
 | ||||
| func copyCookie(c *http.Cookie) *http.Cookie { | ||||
| 	return &http.Cookie{ | ||||
| 		Name:       c.Name, | ||||
| 		Value:      c.Value, | ||||
| 		Path:       c.Path, | ||||
| 		Domain:     c.Domain, | ||||
| 		Expires:    c.Expires, | ||||
| 		RawExpires: c.RawExpires, | ||||
| 		MaxAge:     c.MaxAge, | ||||
| 		Secure:     c.Secure, | ||||
| 		HttpOnly:   c.HttpOnly, | ||||
| 		Raw:        c.Raw, | ||||
| 		Unparsed:   c.Unparsed, | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,19 @@ | |||
| package sessions | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||
| ) | ||||
| 
 | ||||
| // NewSessionStore creates a SessionStore from the provided configuration
 | ||||
| func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||
| 	switch opts.Type { | ||||
| 	case options.CookieSessionStoreType: | ||||
| 		return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | ||||
| 	} | ||||
| } | ||||
|  | @ -0,0 +1,254 @@ | |||
| package sessions_test | ||||
| 
 | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||
| 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/cookies" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||
| ) | ||||
| 
 | ||||
| func TestSessionStore(t *testing.T) { | ||||
| 	RegisterFailHandler(Fail) | ||||
| 	RunSpecs(t, "SessionStore") | ||||
| } | ||||
| 
 | ||||
| var _ = Describe("NewSessionStore", func() { | ||||
| 	var opts *options.SessionOptions | ||||
| 	var cookieOpts *options.CookieOptions | ||||
| 
 | ||||
| 	var request *http.Request | ||||
| 	var response *httptest.ResponseRecorder | ||||
| 	var session *sessionsapi.SessionState | ||||
| 	var ss sessionsapi.SessionStore | ||||
| 
 | ||||
| 	CheckCookieOptions := func() { | ||||
| 		Context("the cookies returned", func() { | ||||
| 			var cookies []*http.Cookie | ||||
| 			BeforeEach(func() { | ||||
| 				cookies = response.Result().Cookies() | ||||
| 			}) | ||||
| 
 | ||||
| 			It("have the correct name set", func() { | ||||
| 				if len(cookies) == 1 { | ||||
| 					Expect(cookies[0].Name).To(Equal(cookieOpts.CookieName)) | ||||
| 				} else { | ||||
| 					for _, cookie := range cookies { | ||||
| 						Expect(cookie.Name).To(ContainSubstring(cookieOpts.CookieName)) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 			It("have the correct path set", func() { | ||||
| 				for _, cookie := range cookies { | ||||
| 					Expect(cookie.Path).To(Equal(cookieOpts.CookiePath)) | ||||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 			It("have the correct domain set", func() { | ||||
| 				for _, cookie := range cookies { | ||||
| 					Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain)) | ||||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 			It("have the correct HTTPOnly set", func() { | ||||
| 				for _, cookie := range cookies { | ||||
| 					Expect(cookie.HttpOnly).To(Equal(cookieOpts.CookieHTTPOnly)) | ||||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 			It("have the correct secure set", func() { | ||||
| 				for _, cookie := range cookies { | ||||
| 					Expect(cookie.Secure).To(Equal(cookieOpts.CookieSecure)) | ||||
| 				} | ||||
| 			}) | ||||
| 
 | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	SessionStoreInterfaceTests := func() { | ||||
| 		Context("when Save is called", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				err := ss.Save(response, request, session) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("sets a `set-cookie` header in the response", func() { | ||||
| 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | ||||
| 			}) | ||||
| 
 | ||||
| 			CheckCookieOptions() | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("when Clear is called", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				cookie := cookies.MakeCookie(request, | ||||
| 					cookieOpts.CookieName, | ||||
| 					"foo", | ||||
| 					cookieOpts.CookiePath, | ||||
| 					cookieOpts.CookieDomain, | ||||
| 					cookieOpts.CookieHTTPOnly, | ||||
| 					cookieOpts.CookieSecure, | ||||
| 					cookieOpts.CookieExpire, | ||||
| 					time.Now(), | ||||
| 				) | ||||
| 				request.AddCookie(cookie) | ||||
| 				err := ss.Clear(response, request) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("sets a `set-cookie` header in the response", func() { | ||||
| 				Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty()) | ||||
| 			}) | ||||
| 
 | ||||
| 			CheckCookieOptions() | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("when Load is called", func() { | ||||
| 			var loadedSession *sessionsapi.SessionState | ||||
| 			BeforeEach(func() { | ||||
| 				req := httptest.NewRequest("GET", "http://example.com/", nil) | ||||
| 				resp := httptest.NewRecorder() | ||||
| 				err := ss.Save(resp, req, session) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				for _, cookie := range resp.Result().Cookies() { | ||||
| 					request.AddCookie(cookie) | ||||
| 				} | ||||
| 				loadedSession, err = ss.Load(request) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("loads a session equal to the original session", func() { | ||||
| 				if cookieOpts.CookieSecret == "" { | ||||
| 					// Only Email and User stored in session when encrypted
 | ||||
| 					Expect(loadedSession.Email).To(Equal(session.Email)) | ||||
| 					Expect(loadedSession.User).To(Equal(session.User)) | ||||
| 				} else { | ||||
| 					// All fields stored in session if encrypted
 | ||||
| 
 | ||||
| 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | ||||
| 					l := *loadedSession | ||||
| 					l.ExpiresOn = time.Time{} | ||||
| 					s := *session | ||||
| 					s.ExpiresOn = time.Time{} | ||||
| 					Expect(l).To(Equal(s)) | ||||
| 
 | ||||
| 					// Compare time.Time separately
 | ||||
| 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	RunSessionTests := func() { | ||||
| 		Context("with default options", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				var err error | ||||
| 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			SessionStoreInterfaceTests() | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with non-default options", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				cookieOpts = &options.CookieOptions{ | ||||
| 					CookieName:     "_cookie_name", | ||||
| 					CookiePath:     "/path", | ||||
| 					CookieExpire:   time.Duration(72) * time.Hour, | ||||
| 					CookieRefresh:  time.Duration(3600), | ||||
| 					CookieSecure:   false, | ||||
| 					CookieHTTPOnly: false, | ||||
| 					CookieDomain:   "example.com", | ||||
| 				} | ||||
| 
 | ||||
| 				var err error | ||||
| 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			SessionStoreInterfaceTests() | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with a cookie-secret set", func() { | ||||
| 			BeforeEach(func() { | ||||
| 				secret := make([]byte, 32) | ||||
| 				_, err := rand.Read(secret) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | ||||
| 
 | ||||
| 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 			}) | ||||
| 
 | ||||
| 			SessionStoreInterfaceTests() | ||||
| 		}) | ||||
| 	} | ||||
| 
 | ||||
| 	BeforeEach(func() { | ||||
| 		ss = nil | ||||
| 		opts = &options.SessionOptions{} | ||||
| 
 | ||||
| 		// Set default options in CookieOptions
 | ||||
| 		cookieOpts = &options.CookieOptions{ | ||||
| 			CookieName:     "_oauth2_proxy", | ||||
| 			CookiePath:     "/", | ||||
| 			CookieExpire:   time.Duration(168) * time.Hour, | ||||
| 			CookieRefresh:  time.Duration(0), | ||||
| 			CookieSecure:   true, | ||||
| 			CookieHTTPOnly: true, | ||||
| 		} | ||||
| 
 | ||||
| 		session = &sessionsapi.SessionState{ | ||||
| 			AccessToken:  "AccessToken", | ||||
| 			IDToken:      "IDToken", | ||||
| 			ExpiresOn:    time.Now().Add(1 * time.Hour), | ||||
| 			RefreshToken: "RefreshToken", | ||||
| 			Email:        "john.doe@example.com", | ||||
| 			User:         "john.doe", | ||||
| 		} | ||||
| 
 | ||||
| 		request = httptest.NewRequest("GET", "http://example.com/", nil) | ||||
| 		response = httptest.NewRecorder() | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with type 'cookie'", func() { | ||||
| 		BeforeEach(func() { | ||||
| 			opts.Type = options.CookieSessionStoreType | ||||
| 		}) | ||||
| 
 | ||||
| 		It("creates a cookie.SessionStore", func() { | ||||
| 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||
| 			Expect(err).NotTo(HaveOccurred()) | ||||
| 			Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{})) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("the cookie.SessionStore", func() { | ||||
| 			RunSessionTests() | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("with an invalid type", func() { | ||||
| 		BeforeEach(func() { | ||||
| 			opts.Type = "invalid-type" | ||||
| 		}) | ||||
| 
 | ||||
| 		It("returns an error", func() { | ||||
| 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||
| 			Expect(err).To(HaveOccurred()) | ||||
| 			Expect(err.Error()).To(Equal("unknown session store type 'invalid-type'")) | ||||
| 			Expect(ss).To(BeNil()) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -0,0 +1,41 @@ | |||
| package utils | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // CookieForSession serializes a session state for storage in a cookie
 | ||||
| func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { | ||||
| 	return s.EncodeSessionState(c) | ||||
| } | ||||
| 
 | ||||
| // SessionFromCookie deserializes a session from a cookie value
 | ||||
| func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { | ||||
| 	return sessions.DecodeSessionState(v, c) | ||||
| } | ||||
| 
 | ||||
| // SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
 | ||||
| func SecretBytes(secret string) []byte { | ||||
| 	b, err := base64.URLEncoding.DecodeString(addPadding(secret)) | ||||
| 	if err == nil { | ||||
| 		return []byte(addPadding(string(b))) | ||||
| 	} | ||||
| 	return []byte(secret) | ||||
| } | ||||
| 
 | ||||
| func addPadding(secret string) string { | ||||
| 	padding := len(secret) % 4 | ||||
| 	switch padding { | ||||
| 	case 1: | ||||
| 		return secret + "===" | ||||
| 	case 2: | ||||
| 		return secret + "==" | ||||
| 	case 3: | ||||
| 		return secret + "=" | ||||
| 	default: | ||||
| 		return secret | ||||
| 	} | ||||
| } | ||||
|  | @ -9,6 +9,7 @@ import ( | |||
| 	"github.com/bitly/go-simplejson" | ||||
| 	"github.com/pusher/oauth2_proxy/api" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // AzureProvider represents an Azure based Identity Provider
 | ||||
|  | @ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { | |||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 	var email string | ||||
| 	var err error | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testAzureProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@windows.net", email) | ||||
|  | @ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testAzureProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@windows.net", email) | ||||
|  | @ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testAzureProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@windows.net", email) | ||||
|  | @ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testAzureProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, "type assertion to string failed", err.Error()) | ||||
| 	assert.Equal(t, "", email) | ||||
|  | @ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testAzureProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  | @ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testAzureProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, "type assertion to string failed", err.Error()) | ||||
| 	assert.Equal(t, "", email) | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/api" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // FacebookProvider represents an Facebook based Identity Provider
 | ||||
|  | @ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header { | |||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
|  | @ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | |||
| } | ||||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { | ||||
| func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||
| 	return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) | ||||
| } | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ import ( | |||
| 	"strings" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // GitHubProvider represents an GitHub based Identity Provider
 | ||||
|  | @ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { | |||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 
 | ||||
| 	var emails []struct { | ||||
| 		Email    string `json:"email"` | ||||
|  | @ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | |||
| } | ||||
| 
 | ||||
| // GetUserName returns the Account user name
 | ||||
| func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { | ||||
| func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { | ||||
| 	var user struct { | ||||
| 		Login string `json:"login"` | ||||
| 		Email string `json:"email"` | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testGitHubProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
|  | @ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testGitHubProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Empty(t, "", email) | ||||
|  | @ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { | |||
| 	p := testGitHubProvider(bURL.Host) | ||||
| 	p.Org = "testorg1" | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
|  | @ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { | |||
| 	// We'll trigger a request failure by using an unexpected access
 | ||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||
| 	// JSON to fail.
 | ||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  | @ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testGitHubProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  | @ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testGitHubProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetUserName(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "mbland", email) | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/api" | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // GitLabProvider represents an GitLab based Identity Provider
 | ||||
|  | @ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { | |||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 
 | ||||
| 	req, err := http.NewRequest("GET", | ||||
| 		p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testGitLabProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||
|  | @ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { | |||
| 	// We'll trigger a request failure by using an unexpected access
 | ||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||
| 	// JSON to fail.
 | ||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  | @ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testGitLabProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/logger" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"golang.org/x/oauth2" | ||||
| 	"golang.org/x/oauth2/google" | ||||
| 	admin "google.golang.org/api/admin/directory/v1" | ||||
|  | @ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||
| func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
|  | @ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err | |||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	s = &SessionState{ | ||||
| 	s = &sessions.SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||
|  | @ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { | |||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
| // RefreshToken to fetch a new ID token if required
 | ||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||
| 		return false, nil | ||||
| 	} | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct { | |||
| 	*ProviderData | ||||
| } | ||||
| 
 | ||||
| func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 	return "", errors.New("not implemented") | ||||
| } | ||||
| 
 | ||||
| // Note that we're testing the internal validateToken() used to implement
 | ||||
| // several Provider's ValidateSessionState() implementations
 | ||||
| func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { | ||||
| func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/api" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // LinkedInProvider represents an LinkedIn based Identity Provider
 | ||||
|  | @ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header { | |||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 	if s.AccessToken == "" { | ||||
| 		return "", errors.New("missing access token") | ||||
| 	} | ||||
|  | @ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | |||
| } | ||||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { | ||||
| func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||
| 	return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | ||||
| } | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"net/url" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
|  | @ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testLinkedInProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, "user@linkedin.com", email) | ||||
|  | @ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { | |||
| 	// We'll trigger a request failure by using an unexpected access
 | ||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||
| 	// JSON to fail.
 | ||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  | @ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | |||
| 	bURL, _ := url.Parse(b.URL) | ||||
| 	p := testLinkedInProvider(bURL.Host) | ||||
| 
 | ||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||
| 	email, err := p.GetEmailAddress(session) | ||||
| 	assert.NotEqual(t, nil, err) | ||||
| 	assert.Equal(t, "", email) | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"gopkg.in/square/go-jose.v2" | ||||
| ) | ||||
| 
 | ||||
|  | @ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||
| func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
|  | @ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er | |||
| 	} | ||||
| 
 | ||||
| 	// Store the data that we found in the session state
 | ||||
| 	s = &SessionState{ | ||||
| 	s = &sessions.SessionState{ | ||||
| 		AccessToken: jsonResponse.AccessToken, | ||||
| 		IDToken:     jsonResponse.IDToken, | ||||
| 		ExpiresOn:   time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||
|  |  | |||
|  | @ -5,9 +5,9 @@ import ( | |||
| 	"fmt" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"golang.org/x/oauth2" | ||||
| 
 | ||||
| 	oidc "github.com/coreos/go-oidc" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
| 
 | ||||
| // OIDCProvider represents an OIDC based Identity Provider
 | ||||
|  | @ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||
| func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| 	ctx := context.Background() | ||||
| 	c := oauth2.Config{ | ||||
| 		ClientID:     p.ClientID, | ||||
|  | @ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er | |||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
| // RefreshToken to fetch a new ID token if required
 | ||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||
| 		return false, nil | ||||
| 	} | ||||
|  | @ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | |||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { | ||||
| func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { | ||||
| 	c := oauth2.Config{ | ||||
| 		ClientID:     p.ClientID, | ||||
| 		ClientSecret: p.ClientSecret, | ||||
|  | @ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { | |||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { | ||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { | ||||
| 	rawIDToken, ok := token.Extra("id_token").(string) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("token response did not contain an id_token") | ||||
|  | @ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | |||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	return &SessionState{ | ||||
| 	return &sessions.SessionState{ | ||||
| 		AccessToken:  token.AccessToken, | ||||
| 		IDToken:      rawIDToken, | ||||
| 		RefreshToken: token.RefreshToken, | ||||
|  | @ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | |||
| } | ||||
| 
 | ||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | ||||
| func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { | ||||
| func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||
| 	ctx := context.Background() | ||||
| 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -10,10 +10,11 @@ import ( | |||
| 	"net/url" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // Redeem provides a default implementation of the OAuth2 token redemption process
 | ||||
| func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { | ||||
| func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||
| 	if code == "" { | ||||
| 		err = errors.New("missing code") | ||||
| 		return | ||||
|  | @ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er | |||
| 	} | ||||
| 	err = json.Unmarshal(body, &jsonResponse) | ||||
| 	if err == nil { | ||||
| 		s = &SessionState{ | ||||
| 		s = &sessions.SessionState{ | ||||
| 			AccessToken: jsonResponse.AccessToken, | ||||
| 		} | ||||
| 		return | ||||
|  | @ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er | |||
| 		return | ||||
| 	} | ||||
| 	if a := v.Get("access_token"); a != "" { | ||||
| 		s = &SessionState{AccessToken: a} | ||||
| 		s = &sessions.SessionState{AccessToken: a} | ||||
| 	} else { | ||||
| 		err = fmt.Errorf("no access token found %s", body) | ||||
| 	} | ||||
|  | @ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string { | |||
| } | ||||
| 
 | ||||
| // CookieForSession serializes a session state for storage in a cookie
 | ||||
| func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { | ||||
| func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { | ||||
| 	return s.EncodeSessionState(c) | ||||
| } | ||||
| 
 | ||||
| // SessionFromCookie deserializes a session from a cookie value
 | ||||
| func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { | ||||
| 	return DecodeSessionState(v, c) | ||||
| func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { | ||||
| 	return sessions.DecodeSessionState(v, c) | ||||
| } | ||||
| 
 | ||||
| // GetEmailAddress returns the Account email address
 | ||||
| func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { | ||||
| func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||
| 	return "", errors.New("not implemented") | ||||
| } | ||||
| 
 | ||||
| // GetUserName returns the Account username
 | ||||
| func (p *ProviderData) GetUserName(s *SessionState) (string, error) { | ||||
| func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { | ||||
| 	return "", errors.New("not implemented") | ||||
| } | ||||
| 
 | ||||
|  | @ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool { | |||
| } | ||||
| 
 | ||||
| // ValidateSessionState validates the AccessToken
 | ||||
| func (p *ProviderData) ValidateSessionState(s *SessionState) bool { | ||||
| func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { | ||||
| 	return validateToken(p, s.AccessToken, nil) | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded should refresh the user's session if required and
 | ||||
| // do nothing if a refresh is not required
 | ||||
| func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||
| 	return false, nil | ||||
| } | ||||
|  |  | |||
|  | @ -4,12 +4,13 @@ import ( | |||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
| 
 | ||||
| func TestRefresh(t *testing.T) { | ||||
| 	p := &ProviderData{} | ||||
| 	refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ | ||||
| 	refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{ | ||||
| 		ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), | ||||
| 	}) | ||||
| 	assert.Equal(t, false, refreshed) | ||||
|  |  | |||
|  | @ -2,20 +2,21 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"github.com/pusher/oauth2_proxy/cookie" | ||||
| 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // Provider represents an upstream identity provider implementation
 | ||||
| type Provider interface { | ||||
| 	Data() *ProviderData | ||||
| 	GetEmailAddress(*SessionState) (string, error) | ||||
| 	GetUserName(*SessionState) (string, error) | ||||
| 	Redeem(string, string) (*SessionState, error) | ||||
| 	GetEmailAddress(*sessions.SessionState) (string, error) | ||||
| 	GetUserName(*sessions.SessionState) (string, error) | ||||
| 	Redeem(string, string) (*sessions.SessionState, error) | ||||
| 	ValidateGroup(string) bool | ||||
| 	ValidateSessionState(*SessionState) bool | ||||
| 	ValidateSessionState(*sessions.SessionState) bool | ||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | ||||
| 	RefreshSessionIfNeeded(*SessionState) (bool, error) | ||||
| 	SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) | ||||
| 	CookieForSession(*SessionState, *cookie.Cipher) (string, error) | ||||
| 	RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) | ||||
| 	SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error) | ||||
| 	CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error) | ||||
| } | ||||
| 
 | ||||
| // New provides a new Provider based on the configured provider string
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue