Merge pull request #147 from pusher/session-store
Add initial session-store interface and implementation
This commit is contained in:
		
						commit
						17e97ab884
					
				|  | @ -10,6 +10,10 @@ | ||||||
| 
 | 
 | ||||||
| ## Changes since v3.2.0 | ## Changes since v3.2.0 | ||||||
| 
 | 
 | ||||||
|  | - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) | ||||||
|  |   - Allows for multiple different session storage implementations including client and server side | ||||||
|  |   - Adds tests suite for interface to ensure consistency across implementations | ||||||
|  |   - Refactor some configuration options (around cookies) into packages | ||||||
| - [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings) | - [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings) | ||||||
| - [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath) | - [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath) | ||||||
| - [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes) | - [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes) | ||||||
|  |  | ||||||
|  | @ -57,6 +57,20 @@ | ||||||
|   pruneopts = "" |   pruneopts = "" | ||||||
|   revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" |   revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" | ||||||
| 
 | 
 | ||||||
|  | [[projects]] | ||||||
|  |   digest = "1:b3c5b95e56c06f5aa72cb2500e6ee5f44fcd122872d4fec2023a488e561218bc" | ||||||
|  |   name = "github.com/hpcloud/tail" | ||||||
|  |   packages = [ | ||||||
|  |     ".", | ||||||
|  |     "ratelimiter", | ||||||
|  |     "util", | ||||||
|  |     "watch", | ||||||
|  |     "winfile", | ||||||
|  |   ] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5" | ||||||
|  |   version = "v1.0.0" | ||||||
|  | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" |   digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" | ||||||
|   name = "github.com/mbland/hmacauth" |   name = "github.com/mbland/hmacauth" | ||||||
|  | @ -73,6 +87,54 @@ | ||||||
|   pruneopts = "" |   pruneopts = "" | ||||||
|   revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" |   revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" | ||||||
| 
 | 
 | ||||||
|  | [[projects]] | ||||||
|  |   digest = "1:a3735b0978a8b53fc2ac97a6f46ca1189f0712a00df86d6ec4cf26c1a25e6d77" | ||||||
|  |   name = "github.com/onsi/ginkgo" | ||||||
|  |   packages = [ | ||||||
|  |     ".", | ||||||
|  |     "config", | ||||||
|  |     "internal/codelocation", | ||||||
|  |     "internal/containernode", | ||||||
|  |     "internal/failer", | ||||||
|  |     "internal/leafnodes", | ||||||
|  |     "internal/remote", | ||||||
|  |     "internal/spec", | ||||||
|  |     "internal/spec_iterator", | ||||||
|  |     "internal/specrunner", | ||||||
|  |     "internal/suite", | ||||||
|  |     "internal/testingtproxy", | ||||||
|  |     "internal/writer", | ||||||
|  |     "reporters", | ||||||
|  |     "reporters/stenographer", | ||||||
|  |     "reporters/stenographer/support/go-colorable", | ||||||
|  |     "reporters/stenographer/support/go-isatty", | ||||||
|  |     "types", | ||||||
|  |   ] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "eea6ad008b96acdaa524f5b409513bf062b500ad" | ||||||
|  |   version = "v1.8.0" | ||||||
|  | 
 | ||||||
|  | [[projects]] | ||||||
|  |   digest = "1:dbafce2fddb1ca331646fe2ac9c9413980368b19a60a4406a6e5861680bd73be" | ||||||
|  |   name = "github.com/onsi/gomega" | ||||||
|  |   packages = [ | ||||||
|  |     ".", | ||||||
|  |     "format", | ||||||
|  |     "internal/assertion", | ||||||
|  |     "internal/asyncassertion", | ||||||
|  |     "internal/oraclematcher", | ||||||
|  |     "internal/testingtsupport", | ||||||
|  |     "matchers", | ||||||
|  |     "matchers/support/goraph/bipartitegraph", | ||||||
|  |     "matchers/support/goraph/edge", | ||||||
|  |     "matchers/support/goraph/node", | ||||||
|  |     "matchers/support/goraph/util", | ||||||
|  |     "types", | ||||||
|  |   ] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "90e289841c1ed79b7a598a7cd9959750cb5e89e2" | ||||||
|  |   version = "v1.5.0" | ||||||
|  | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" |   digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" | ||||||
|   name = "github.com/pmezard/go-difflib" |   name = "github.com/pmezard/go-difflib" | ||||||
|  | @ -131,6 +193,9 @@ | ||||||
|   packages = [ |   packages = [ | ||||||
|     "context", |     "context", | ||||||
|     "context/ctxhttp", |     "context/ctxhttp", | ||||||
|  |     "html", | ||||||
|  |     "html/atom", | ||||||
|  |     "html/charset", | ||||||
|     "websocket", |     "websocket", | ||||||
|   ] |   ] | ||||||
|   pruneopts = "" |   pruneopts = "" | ||||||
|  | @ -150,6 +215,42 @@ | ||||||
|   pruneopts = "" |   pruneopts = "" | ||||||
|   revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" |   revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" | ||||||
| 
 | 
 | ||||||
|  | [[projects]] | ||||||
|  |   branch = "master" | ||||||
|  |   digest = "1:67a6e61e60283fd7dce50eba228080bff8805d9d69b2f121d7ec2260d120c4a8" | ||||||
|  |   name = "golang.org/x/sys" | ||||||
|  |   packages = ["unix"] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "ca7f33d4116e3a1f9425755d6a44e7ed9b4c97df" | ||||||
|  | 
 | ||||||
|  | [[projects]] | ||||||
|  |   digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" | ||||||
|  |   name = "golang.org/x/text" | ||||||
|  |   packages = [ | ||||||
|  |     "encoding", | ||||||
|  |     "encoding/charmap", | ||||||
|  |     "encoding/htmlindex", | ||||||
|  |     "encoding/internal", | ||||||
|  |     "encoding/internal/identifier", | ||||||
|  |     "encoding/japanese", | ||||||
|  |     "encoding/korean", | ||||||
|  |     "encoding/simplifiedchinese", | ||||||
|  |     "encoding/traditionalchinese", | ||||||
|  |     "encoding/unicode", | ||||||
|  |     "internal/gen", | ||||||
|  |     "internal/language", | ||||||
|  |     "internal/language/compact", | ||||||
|  |     "internal/tag", | ||||||
|  |     "internal/utf8internal", | ||||||
|  |     "language", | ||||||
|  |     "runes", | ||||||
|  |     "transform", | ||||||
|  |     "unicode/cldr", | ||||||
|  |   ] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" | ||||||
|  |   version = "v0.3.2" | ||||||
|  | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|   digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" |   digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" | ||||||
|  | @ -182,6 +283,15 @@ | ||||||
|   revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" |   revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" | ||||||
|   version = "v1.0.0" |   version = "v1.0.0" | ||||||
| 
 | 
 | ||||||
|  | [[projects]] | ||||||
|  |   digest = "1:eb53021a8aa3f599d29c7102e65026242bdedce998a54837dc67f14b6a97c5fd" | ||||||
|  |   name = "gopkg.in/fsnotify.v1" | ||||||
|  |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" | ||||||
|  |   source = "https://github.com/fsnotify/fsnotify.git" | ||||||
|  |   version = "v1.4.7" | ||||||
|  | 
 | ||||||
| [[projects]] | [[projects]] | ||||||
|   digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" |   digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" | ||||||
|   name = "gopkg.in/fsnotify/fsnotify.v1" |   name = "gopkg.in/fsnotify/fsnotify.v1" | ||||||
|  | @ -210,6 +320,22 @@ | ||||||
|   revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" |   revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" | ||||||
|   version = "v2.1.3" |   version = "v2.1.3" | ||||||
| 
 | 
 | ||||||
|  | [[projects]] | ||||||
|  |   branch = "v1" | ||||||
|  |   digest = "1:a96d16bd088460f2e0685d46c39bcf1208ba46e0a977be2df49864ec7da447dd" | ||||||
|  |   name = "gopkg.in/tomb.v1" | ||||||
|  |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8" | ||||||
|  | 
 | ||||||
|  | [[projects]] | ||||||
|  |   digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" | ||||||
|  |   name = "gopkg.in/yaml.v2" | ||||||
|  |   packages = ["."] | ||||||
|  |   pruneopts = "" | ||||||
|  |   revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" | ||||||
|  |   version = "v2.2.2" | ||||||
|  | 
 | ||||||
| [solve-meta] | [solve-meta] | ||||||
|   analyzer-name = "dep" |   analyzer-name = "dep" | ||||||
|   analyzer-version = 1 |   analyzer-version = 1 | ||||||
|  | @ -220,6 +346,8 @@ | ||||||
|     "github.com/dgrijalva/jwt-go", |     "github.com/dgrijalva/jwt-go", | ||||||
|     "github.com/mbland/hmacauth", |     "github.com/mbland/hmacauth", | ||||||
|     "github.com/mreiferson/go-options", |     "github.com/mreiferson/go-options", | ||||||
|  |     "github.com/onsi/ginkgo", | ||||||
|  |     "github.com/onsi/gomega", | ||||||
|     "github.com/stretchr/testify/assert", |     "github.com/stretchr/testify/assert", | ||||||
|     "github.com/stretchr/testify/require", |     "github.com/stretchr/testify/require", | ||||||
|     "github.com/yhat/wsutil", |     "github.com/yhat/wsutil", | ||||||
|  | @ -231,6 +359,7 @@ | ||||||
|     "google.golang.org/api/googleapi", |     "google.golang.org/api/googleapi", | ||||||
|     "gopkg.in/fsnotify/fsnotify.v1", |     "gopkg.in/fsnotify/fsnotify.v1", | ||||||
|     "gopkg.in/natefinch/lumberjack.v2", |     "gopkg.in/natefinch/lumberjack.v2", | ||||||
|  |     "gopkg.in/square/go-jose.v2", | ||||||
|   ] |   ] | ||||||
|   solver-name = "gps-cdcl" |   solver-name = "gps-cdcl" | ||||||
|   solver-version = 1 |   solver-version = 1 | ||||||
|  |  | ||||||
|  | @ -35,6 +35,10 @@ | ||||||
|   name = "gopkg.in/fsnotify/fsnotify.v1" |   name = "gopkg.in/fsnotify/fsnotify.v1" | ||||||
|   version = "~1.2.0" |   version = "~1.2.0" | ||||||
| 
 | 
 | ||||||
|  | [[override]] | ||||||
|  |   name = "gopkg.in/fsnotify.v1" | ||||||
|  |   source = "https://github.com/fsnotify/fsnotify.git" | ||||||
|  | 
 | ||||||
| [[constraint]] | [[constraint]] | ||||||
|   branch = "master" |   branch = "master" | ||||||
|   name = "golang.org/x/crypto" |   name = "golang.org/x/crypto" | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								Makefile
								
								
								
								
							
							
						
						
									
										1
									
								
								Makefile
								
								
								
								
							|  | @ -33,6 +33,7 @@ lint: $(GOMETALINTER) | ||||||
| 		--enable=deadcode \
 | 		--enable=deadcode \
 | ||||||
| 		--enable=gofmt \
 | 		--enable=gofmt \
 | ||||||
| 		--enable=goimports \
 | 		--enable=goimports \
 | ||||||
|  | 		--deadline=120s \
 | ||||||
| 		--tests ./... | 		--tests ./... | ||||||
| 
 | 
 | ||||||
| .PHONY: dep | .PHONY: dep | ||||||
|  |  | ||||||
|  | @ -1,7 +1,8 @@ | ||||||
| --- | --- | ||||||
| layout: default | layout: default | ||||||
| title: Configuration | title: Configuration | ||||||
| permalink: /configuration | permalink: /docs/configuration | ||||||
|  | has_children: true | ||||||
| nav_order: 3 | nav_order: 3 | ||||||
| --- | --- | ||||||
| 
 | 
 | ||||||
|  | @ -78,6 +79,7 @@ Usage of oauth2_proxy: | ||||||
|   -request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below) |   -request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below) | ||||||
|   -resource string: The resource that is protected (Azure AD only) |   -resource string: The resource that is protected (Azure AD only) | ||||||
|   -scope string: OAuth scope specification |   -scope string: OAuth scope specification | ||||||
|  |   -session-store-type: Session data storage backend (default: cookie) | ||||||
|   -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) |   -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) | ||||||
|   -set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode) |   -set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode) | ||||||
|   -signature-key string: GAP-Signature request signature key (algorithm:secretkey) |   -signature-key string: GAP-Signature request signature key (algorithm:secretkey) | ||||||
|  | @ -0,0 +1,34 @@ | ||||||
|  | --- | ||||||
|  | layout: default | ||||||
|  | title: Sessions | ||||||
|  | permalink: /configuration | ||||||
|  | parent: Configuration | ||||||
|  | nav_order: 3 | ||||||
|  | --- | ||||||
|  | 
 | ||||||
|  | ## Sessions | ||||||
|  | 
 | ||||||
|  | Sessions allow a user's authentication to be tracked between multiple HTTP | ||||||
|  | requests to a service. | ||||||
|  | 
 | ||||||
|  | The OAuth2 Proxy uses a Cookie to track user sessions and will store the session | ||||||
|  | data in one of the available session storage backends. | ||||||
|  | 
 | ||||||
|  | At present the available backends are (as passed to `--session-store-type`): | ||||||
|  | - [cookie](cookie-storage) (deafult) | ||||||
|  | 
 | ||||||
|  | ### Cookie Storage | ||||||
|  | 
 | ||||||
|  | The Cookie storage backend is the default backend implementation and has | ||||||
|  | been used in the OAuth2 Proxy historically. | ||||||
|  | 
 | ||||||
|  | With the Cookie storage backend, all session information is stored in client | ||||||
|  | side cookies and transferred with each and every request. | ||||||
|  | 
 | ||||||
|  | The following should be known when using this implementation: | ||||||
|  | - Since all state is stored client side, this storage backend means that the OAuth2 Proxy is completely stateless | ||||||
|  | - Cookies are signed server side to prevent modification client-side | ||||||
|  | - It is recommended to set a `cookie-secret` which will ensure data is encrypted within the cookie data. | ||||||
|  | - Since multiple requests can be made concurrently to the OAuth2 Proxy, this session implementation | ||||||
|  | cannot lock sessions and while updating and refreshing sessions, there can be conflicts which force | ||||||
|  | users to re-authenticate | ||||||
|  | @ -15,14 +15,27 @@ type EnvOptions map[string]interface{} | ||||||
| // Fields in the options struct must have an `env` and `cfg` tag to be read
 | // Fields in the options struct must have an `env` and `cfg` tag to be read
 | ||||||
| // from the environment
 | // from the environment
 | ||||||
| func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { | func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { | ||||||
| 	val := reflect.ValueOf(options).Elem() | 	val := reflect.ValueOf(options) | ||||||
| 	typ := val.Type() | 	var typ reflect.Type | ||||||
|  | 	if val.Kind() == reflect.Ptr { | ||||||
|  | 		typ = val.Elem().Type() | ||||||
|  | 	} else { | ||||||
|  | 		typ = val.Type() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	for i := 0; i < typ.NumField(); i++ { | 	for i := 0; i < typ.NumField(); i++ { | ||||||
| 		// pull out the struct tags:
 | 		// pull out the struct tags:
 | ||||||
| 		//    flag - the name of the command line flag
 | 		//    flag - the name of the command line flag
 | ||||||
| 		//    deprecated - (optional) the name of the deprecated command line flag
 | 		//    deprecated - (optional) the name of the deprecated command line flag
 | ||||||
| 		//    cfg - (optional, defaults to underscored flag) the name of the config file option
 | 		//    cfg - (optional, defaults to underscored flag) the name of the config file option
 | ||||||
| 		field := typ.Field(i) | 		field := typ.Field(i) | ||||||
|  | 		fieldV := reflect.Indirect(val).Field(i) | ||||||
|  | 
 | ||||||
|  | 		if field.Type.Kind() == reflect.Struct && field.Anonymous { | ||||||
|  | 			cfg.LoadEnvForStruct(fieldV.Interface()) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
| 		flagName := field.Tag.Get("flag") | 		flagName := field.Tag.Get("flag") | ||||||
| 		envName := field.Tag.Get("env") | 		envName := field.Tag.Get("env") | ||||||
| 		cfgName := field.Tag.Get("cfg") | 		cfgName := field.Tag.Get("cfg") | ||||||
|  |  | ||||||
|  | @ -1,26 +1,46 @@ | ||||||
| package main | package main_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"os" | 	"os" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	proxy "github.com/pusher/oauth2_proxy" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type envTest struct { | type EnvTest struct { | ||||||
| 	testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` | 	TestField string `cfg:"target_field" env:"TEST_ENV_FIELD"` | ||||||
|  | 	EnvTestEmbed | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type EnvTestEmbed struct { | ||||||
|  | 	TestFieldEmbed string `cfg:"target_field_embed" env:"TEST_ENV_FIELD_EMBED"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestLoadEnvForStruct(t *testing.T) { | func TestLoadEnvForStruct(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	cfg := make(EnvOptions) | 	cfg := make(proxy.EnvOptions) | ||||||
| 	cfg.LoadEnvForStruct(&envTest{}) | 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||||
| 
 | 
 | ||||||
| 	_, ok := cfg["target_field"] | 	_, ok := cfg["target_field"] | ||||||
| 	assert.Equal(t, ok, false) | 	assert.Equal(t, ok, false) | ||||||
| 
 | 
 | ||||||
| 	os.Setenv("TEST_ENV_FIELD", "1234abcd") | 	os.Setenv("TEST_ENV_FIELD", "1234abcd") | ||||||
| 	cfg.LoadEnvForStruct(&envTest{}) | 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||||
| 	v := cfg["target_field"] | 	v := cfg["target_field"] | ||||||
| 	assert.Equal(t, v, "1234abcd") | 	assert.Equal(t, v, "1234abcd") | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func TestLoadEnvForStructWithEmbeddedFields(t *testing.T) { | ||||||
|  | 
 | ||||||
|  | 	cfg := make(proxy.EnvOptions) | ||||||
|  | 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||||
|  | 
 | ||||||
|  | 	_, ok := cfg["target_field_embed"] | ||||||
|  | 	assert.Equal(t, ok, false) | ||||||
|  | 
 | ||||||
|  | 	os.Setenv("TEST_ENV_FIELD_EMBED", "1234abcd") | ||||||
|  | 	cfg.LoadEnvForStruct(&EnvTest{}) | ||||||
|  | 	v := cfg["target_field_embed"] | ||||||
|  | 	assert.Equal(t, v, "1234abcd") | ||||||
|  | } | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								main.go
								
								
								
								
							
							
						
						
									
										2
									
								
								main.go
								
								
								
								
							|  | @ -75,6 +75,8 @@ func main() { | ||||||
| 	flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") | 	flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") | ||||||
| 	flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") | 	flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") | ||||||
| 
 | 
 | ||||||
|  | 	flagSet.String("session-store-type", "cookie", "the session storage provider to use") | ||||||
|  | 
 | ||||||
| 	flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") | 	flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") | ||||||
| 	flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation") | 	flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation") | ||||||
| 	flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files") | 	flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files") | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ import ( | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/yhat/wsutil" | 	"github.com/yhat/wsutil" | ||||||
| ) | ) | ||||||
|  | @ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { | ||||||
| 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | 	return p.HtpasswdFile != nil && p.DisplayHtpasswdForm | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { | func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, errors.New("missing code") | 		return nil, errors.New("missing code") | ||||||
| 	} | 	} | ||||||
|  | @ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // LoadCookiedSession reads the user's authentication details from the request
 | // LoadCookiedSession reads the user's authentication details from the request
 | ||||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { | func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { | ||||||
| 	var age time.Duration | 	var age time.Duration | ||||||
| 	c, err := loadCookie(req, p.CookieName) | 	c, err := loadCookie(req, p.CookieName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | @ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SaveSession creates a new session cookie value and sets this on the response
 | // SaveSession creates a new session cookie value and sets this on the response
 | ||||||
| func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { | func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||||
| 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | 	value, err := p.provider.CookieForSession(s, p.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|  | @ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { | ||||||
| 
 | 
 | ||||||
| 	user, ok := p.ManualSignIn(rw, req) | 	user, ok := p.ManualSignIn(rw, req) | ||||||
| 	if ok { | 	if ok { | ||||||
| 		session := &providers.SessionState{User: user} | 		session := &sessions.SessionState{User: user} | ||||||
| 		p.SaveSession(rw, req, session) | 		p.SaveSession(rw, req, session) | ||||||
| 		http.Redirect(rw, req, redirect, 302) | 		http.Redirect(rw, req, redirect, 302) | ||||||
| 	} else { | 	} else { | ||||||
|  | @ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | ||||||
| 
 | 
 | ||||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | // credentials and authenticates these against the proxies HtpasswdFile
 | ||||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { | func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { | ||||||
| 	if p.HtpasswdFile == nil { | 	if p.HtpasswdFile == nil { | ||||||
| 		return nil, nil | 		return nil, nil | ||||||
| 	} | 	} | ||||||
|  | @ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, | ||||||
| 	} | 	} | ||||||
| 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | 	if p.HtpasswdFile.Validate(pair[0], pair[1]) { | ||||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||||
| 		return &providers.SessionState{User: pair[0]}, nil | 		return &sessions.SessionState{User: pair[0]}, nil | ||||||
| 	} | 	} | ||||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| 	"github.com/stretchr/testify/require" | 	"github.com/stretchr/testify/require" | ||||||
|  | @ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { | func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { | ||||||
| 	return tp.EmailAddress, nil | 	return tp.EmailAddress, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { | func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { | ||||||
| 	return tp.ValidToken | 	return tp.ValidToken | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook | ||||||
| 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { | func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { | ||||||
| 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | 	value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|  | @ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { | func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { | ||||||
| 	return p.proxy.LoadCookiedSession(p.req) | 	return p.proxy.LoadCookiedSession(p.req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestLoadCookiedSession(t *testing.T) { | func TestLoadCookiedSession(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 
 | 
 | ||||||
| 	startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} | ||||||
| 	pcTest.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession, time.Now()) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, _, err := pcTest.LoadCookiedSession() | ||||||
|  | @ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | 	pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | 	reference := time.Now().Add(time.Duration(-2) * time.Hour) | ||||||
| 
 | 
 | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
| 	session, age, err := pcTest.LoadCookiedSession() | 	session, age, err := pcTest.LoadCookiedSession() | ||||||
|  | @ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
| 	session, _, err := pcTest.LoadCookiedSession() | 	session, _, err := pcTest.LoadCookiedSession() | ||||||
|  | @ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { | ||||||
| 	pcTest := NewProcessCookieTestWithDefaults() | 	pcTest := NewProcessCookieTestWithDefaults() | ||||||
| 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | 	pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	pcTest.SaveSession(startSession, reference) | 	pcTest.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
| 	pcTest.proxy.CookieRefresh = time.Hour | 	pcTest.proxy.CookieRefresh = time.Hour | ||||||
|  | @ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyEndpointAccepted(t *testing.T) { | func TestAuthOnlyEndpointAccepted(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	startSession := &providers.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession, time.Now()) | ||||||
| 
 | 
 | ||||||
|  | @ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	test.proxy.CookieExpire = time.Duration(24) * time.Hour | 	test.proxy.CookieExpire = time.Duration(24) * time.Hour | ||||||
| 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | 	reference := time.Now().Add(time.Duration(25) * time.Hour * -1) | ||||||
| 	startSession := &providers.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	test.SaveSession(startSession, reference) | 	test.SaveSession(startSession, reference) | ||||||
| 
 | 
 | ||||||
|  | @ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { | ||||||
| 	test := NewAuthOnlyEndpointTest() | 	test := NewAuthOnlyEndpointTest() | ||||||
| 	startSession := &providers.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | 		Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} | ||||||
| 	test.SaveSession(startSession, time.Now()) | 	test.SaveSession(startSession, time.Now()) | ||||||
| 	test.validateUser = false | 	test.validateUser = false | ||||||
|  | @ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { | ||||||
| 	pcTest.req, _ = http.NewRequest("GET", | 	pcTest.req, _ = http.NewRequest("GET", | ||||||
| 		pcTest.opts.ProxyPrefix+"/auth", nil) | 		pcTest.opts.ProxyPrefix+"/auth", nil) | ||||||
| 
 | 
 | ||||||
| 	startSession := &providers.SessionState{ | 	startSession := &sessions.SessionState{ | ||||||
| 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | 		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} | ||||||
| 	pcTest.SaveSession(startSession, time.Now()) | 	pcTest.SaveSession(startSession, time.Now()) | ||||||
| 
 | 
 | ||||||
|  | @ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { | ||||||
| 	req := httptest.NewRequest(method, "/foo/bar", bodyBuf) | 	req := httptest.NewRequest(method, "/foo/bar", bodyBuf) | ||||||
| 	req.Header = st.header | 	req.Header = st.header | ||||||
| 
 | 
 | ||||||
| 	state := &providers.SessionState{ | 	state := &sessions.SessionState{ | ||||||
| 		Email: "mbland@acm.org", AccessToken: "my_access_token"} | 		Email: "mbland@acm.org", AccessToken: "my_access_token"} | ||||||
| 	value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) | 	value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
							
								
								
									
										16
									
								
								options.go
								
								
								
								
							
							
						
						
									
										16
									
								
								options.go
								
								
								
								
							|  | @ -18,6 +18,7 @@ import ( | ||||||
| 	"github.com/dgrijalva/jwt-go" | 	"github.com/dgrijalva/jwt-go" | ||||||
| 	"github.com/mbland/hmacauth" | 	"github.com/mbland/hmacauth" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
| 	"github.com/pusher/oauth2_proxy/providers" | 	"github.com/pusher/oauth2_proxy/providers" | ||||||
| 	"gopkg.in/natefinch/lumberjack.v2" | 	"gopkg.in/natefinch/lumberjack.v2" | ||||||
| ) | ) | ||||||
|  | @ -49,14 +50,11 @@ type Options struct { | ||||||
| 	CustomTemplatesDir       string   `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"` | 	CustomTemplatesDir       string   `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"` | ||||||
| 	Footer                   string   `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"` | 	Footer                   string   `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"` | ||||||
| 
 | 
 | ||||||
| 	CookieName     string        `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` | 	// Embed CookieOptions
 | ||||||
| 	CookieSecret   string        `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` | 	options.CookieOptions | ||||||
| 	CookieDomain   string        `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` | 
 | ||||||
| 	CookiePath     string        `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` | 	// Embed SessionOptions
 | ||||||
| 	CookieExpire   time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` | 	options.SessionOptions | ||||||
| 	CookieRefresh  time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` |  | ||||||
| 	CookieSecure   bool          `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` |  | ||||||
| 	CookieHTTPOnly bool          `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` |  | ||||||
| 
 | 
 | ||||||
| 	Upstreams             []string      `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"` | 	Upstreams             []string      `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"` | ||||||
| 	SkipAuthRegex         []string      `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"` | 	SkipAuthRegex         []string      `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"` | ||||||
|  | @ -131,11 +129,13 @@ func NewOptions() *Options { | ||||||
| 		HTTPAddress:         "127.0.0.1:4180", | 		HTTPAddress:         "127.0.0.1:4180", | ||||||
| 		HTTPSAddress:        ":443", | 		HTTPSAddress:        ":443", | ||||||
| 		DisplayHtpasswdForm: true, | 		DisplayHtpasswdForm: true, | ||||||
|  | 		CookieOptions: options.CookieOptions{ | ||||||
| 			CookieName:     "_oauth2_proxy", | 			CookieName:     "_oauth2_proxy", | ||||||
| 			CookieSecure:   true, | 			CookieSecure:   true, | ||||||
| 			CookieHTTPOnly: true, | 			CookieHTTPOnly: true, | ||||||
| 			CookieExpire:   time.Duration(168) * time.Hour, | 			CookieExpire:   time.Duration(168) * time.Hour, | ||||||
| 			CookieRefresh:  time.Duration(0), | 			CookieRefresh:  time.Duration(0), | ||||||
|  | 		}, | ||||||
| 		SetXAuthRequest:       false, | 		SetXAuthRequest:       false, | ||||||
| 		SkipAuthPreflight:     false, | 		SkipAuthPreflight:     false, | ||||||
| 		PassBasicAuth:         true, | 		PassBasicAuth:         true, | ||||||
|  |  | ||||||
|  | @ -0,0 +1,15 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | import "time" | ||||||
|  | 
 | ||||||
|  | // CookieOptions contains configuration options relating to Cookie configuration
 | ||||||
|  | type CookieOptions struct { | ||||||
|  | 	CookieName     string        `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` | ||||||
|  | 	CookieSecret   string        `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` | ||||||
|  | 	CookieDomain   string        `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` | ||||||
|  | 	CookiePath     string        `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` | ||||||
|  | 	CookieExpire   time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` | ||||||
|  | 	CookieRefresh  time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` | ||||||
|  | 	CookieSecure   bool          `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` | ||||||
|  | 	CookieHTTPOnly bool          `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` | ||||||
|  | } | ||||||
|  | @ -0,0 +1,14 @@ | ||||||
|  | package options | ||||||
|  | 
 | ||||||
|  | // SessionOptions contains configuration options for the SessionStore providers.
 | ||||||
|  | type SessionOptions struct { | ||||||
|  | 	Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` | ||||||
|  | 	CookieStoreOptions | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CookieSessionStoreType is used to indicate the CookieSessionStore should be
 | ||||||
|  | // used for storing sessions.
 | ||||||
|  | var CookieSessionStoreType = "cookie" | ||||||
|  | 
 | ||||||
|  | // CookieStoreOptions contains configuration options for the CookieSessionStore.
 | ||||||
|  | type CookieStoreOptions struct{} | ||||||
|  | @ -0,0 +1,12 @@ | ||||||
|  | package sessions | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // SessionStore is an interface to storing user sessions in the proxy
 | ||||||
|  | type SessionStore interface { | ||||||
|  | 	Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error | ||||||
|  | 	Load(req *http.Request) (*SessionState, error) | ||||||
|  | 	Clear(rw http.ResponseWriter, req *http.Request) error | ||||||
|  | } | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| package providers | package sessions | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| package providers | package sessions_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := cookie.NewCipher([]byte(altSecret)) | 	c2, err := cookie.NewCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
| 		IDToken:      "rawtoken1234", | 		IDToken:      "rawtoken1234", | ||||||
|  | @ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	encoded, err := s.EncodeSessionState(c) | 	encoded, err := s.EncodeSessionState(c) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	ss, err := DecodeSessionState(encoded, c) | 	ss, err := sessions.DecodeSessionState(encoded, c) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@domain.com", ss.User) | 	assert.Equal(t, "user@domain.com", ss.User) | ||||||
|  | @ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||||
| 	ss, err = DecodeSessionState(encoded, c2) | 	ss, err = sessions.DecodeSessionState(encoded, c2) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.NotEqual(t, "user@domain.com", ss.User) | 	assert.NotEqual(t, "user@domain.com", ss.User) | ||||||
|  | @ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	c2, err := cookie.NewCipher([]byte(altSecret)) | 	c2, err := cookie.NewCipher([]byte(altSecret)) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	s := &SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		User:         "just-user", | 		User:         "just-user", | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
|  | @ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	encoded, err := s.EncodeSessionState(c) | 	encoded, err := s.EncodeSessionState(c) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	ss, err := DecodeSessionState(encoded, c) | 	ss, err := sessions.DecodeSessionState(encoded, c) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, s.User, ss.User) | 	assert.Equal(t, s.User, ss.User) | ||||||
|  | @ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
| 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | 	// ensure a different cipher can't decode properly (ie: it gets gibberish)
 | ||||||
| 	ss, err = DecodeSessionState(encoded, c2) | 	ss, err = sessions.DecodeSessionState(encoded, c2) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.NotEqual(t, s.User, ss.User) | 	assert.NotEqual(t, s.User, ss.User) | ||||||
|  | @ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerializationNoCipher(t *testing.T) { | func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| 	s := &SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||||
|  | @ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	// only email should have been serialized
 | 	// only email should have been serialized
 | ||||||
| 	ss, err := DecodeSessionState(encoded, nil) | 	ss, err := sessions.DecodeSessionState(encoded, nil) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@domain.com", ss.User) | 	assert.Equal(t, "user@domain.com", ss.User) | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
|  | @ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| 	s := &SessionState{ | 	s := &sessions.SessionState{ | ||||||
| 		User:         "just-user", | 		User:         "just-user", | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
|  | @ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 
 | 
 | ||||||
| 	// only email should have been serialized
 | 	// only email should have been serialized
 | ||||||
| 	ss, err := DecodeSessionState(encoded, nil) | 	ss, err := sessions.DecodeSessionState(encoded, nil) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, s.User, ss.User) | 	assert.Equal(t, s.User, ss.User) | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
|  | @ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestExpired(t *testing.T) { | func TestExpired(t *testing.T) { | ||||||
| 	s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} | 	s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} | ||||||
| 	assert.Equal(t, true, s.IsExpired()) | 	assert.Equal(t, true, s.IsExpired()) | ||||||
| 
 | 
 | ||||||
| 	s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} | 	s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} | ||||||
| 	assert.Equal(t, false, s.IsExpired()) | 	assert.Equal(t, false, s.IsExpired()) | ||||||
| 
 | 
 | ||||||
| 	s = &SessionState{} | 	s = &sessions.SessionState{} | ||||||
| 	assert.Equal(t, false, s.IsExpired()) | 	assert.Equal(t, false, s.IsExpired()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type testCase struct { | type testCase struct { | ||||||
| 	SessionState | 	sessions.SessionState | ||||||
| 	Encoded string | 	Encoded string | ||||||
| 	Cipher  *cookie.Cipher | 	Cipher  *cookie.Cipher | ||||||
| 	Error   bool | 	Error   bool | ||||||
|  | @ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email:        "user@domain.com", | 				Email:        "user@domain.com", | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
|  | @ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	for i, tc := range testCases { | 	for i, tc := range testCases { | ||||||
| 		encoded, err := tc.EncodeSessionState(tc.Cipher) | 		encoded, err := tc.EncodeSessionState(tc.Cipher) | ||||||
| 		t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) | ||||||
| 		if tc.Error { | 		if tc.Error { | ||||||
| 			assert.Error(t, err) | 			assert.Error(t, err) | ||||||
| 			assert.Empty(t, encoded) | 			assert.Empty(t, encoded) | ||||||
|  | @ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // TestDecodeSessionState tests DecodeSessionState with the test vector
 | // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
 | ||||||
| func TestDecodeSessionState(t *testing.T) { | func TestDecodeSessionState(t *testing.T) { | ||||||
| 	e := time.Now().Add(time.Duration(1) * time.Hour) | 	e := time.Now().Add(time.Duration(1) * time.Hour) | ||||||
| 	eJSON, _ := e.MarshalJSON() | 	eJSON, _ := e.MarshalJSON() | ||||||
|  | @ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	testCases := []testCase{ | 	testCases := []testCase{ | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | 			Encoded: `{"Email":"user@domain.com","User":"just-user"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "user@domain.com", | 				User:  "user@domain.com", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"Email":"user@domain.com"}`, | 			Encoded: `{"Email":"user@domain.com"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				User: "just-user", | 				User: "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: `{"User":"just-user"}`, | 			Encoded: `{"User":"just-user"}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
| 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | 			Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email:        "user@domain.com", | 				Email:        "user@domain.com", | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
|  | @ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 			Cipher:  c, | 			Cipher:  c, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 			}, | 			}, | ||||||
|  | @ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 			Error:   true, | 			Error:   true, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				User:  "just-user", | 				User:  "just-user", | ||||||
| 				Email: "user@domain.com", | 				Email: "user@domain.com", | ||||||
| 			}, | 			}, | ||||||
|  | @ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 			Error:   true, | 			Error:   true, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email:        "user@domain.com", | 				Email:        "user@domain.com", | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
|  | @ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 			Cipher:  c, | 			Cipher:  c, | ||||||
| 		}, | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			SessionState: SessionState{ | 			SessionState: sessions.SessionState{ | ||||||
| 				Email:        "user@domain.com", | 				Email:        "user@domain.com", | ||||||
| 				User:         "just-user", | 				User:         "just-user", | ||||||
| 				AccessToken:  "token1234", | 				AccessToken:  "token1234", | ||||||
|  | @ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for i, tc := range testCases { | 	for i, tc := range testCases { | ||||||
| 		ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) | 		ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) | ||||||
| 		t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | 		t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) | ||||||
| 		if tc.Error { | 		if tc.Error { | ||||||
| 			assert.Error(t, err) | 			assert.Error(t, err) | ||||||
| 			assert.Nil(t, ss) | 			assert.Nil(t, ss) | ||||||
|  | @ -0,0 +1,34 @@ | ||||||
|  | package cookies | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // MakeCookie constructs a cookie from the given parameters,
 | ||||||
|  | // discovering the domain from the request if not specified.
 | ||||||
|  | func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time) *http.Cookie { | ||||||
|  | 	if domain != "" { | ||||||
|  | 		host := req.Host | ||||||
|  | 		if h, _, err := net.SplitHostPort(host); err == nil { | ||||||
|  | 			host = h | ||||||
|  | 		} | ||||||
|  | 		if !strings.HasSuffix(host, domain) { | ||||||
|  | 			logger.Printf("Warning: request host is %q but using configured cookie domain of %q", host, domain) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &http.Cookie{ | ||||||
|  | 		Name:     name, | ||||||
|  | 		Value:    value, | ||||||
|  | 		Path:     path, | ||||||
|  | 		Domain:   domain, | ||||||
|  | 		HttpOnly: httpOnly, | ||||||
|  | 		Secure:   secure, | ||||||
|  | 		Expires:  now.Add(expiration), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,232 @@ | ||||||
|  | package cookie | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"regexp" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/cookies" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions/utils" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	// Cookies are limited to 4kb including the length of the cookie name,
 | ||||||
|  | 	// the cookie name can be up to 256 bytes
 | ||||||
|  | 	maxCookieLength = 3840 | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // Ensure CookieSessionStore implements the interface
 | ||||||
|  | var _ sessions.SessionStore = &SessionStore{} | ||||||
|  | 
 | ||||||
|  | // SessionStore is an implementation of the sessions.SessionStore
 | ||||||
|  | // interface that stores sessions in client side cookies
 | ||||||
|  | type SessionStore struct { | ||||||
|  | 	CookieCipher   *cookie.Cipher | ||||||
|  | 	CookieDomain   string | ||||||
|  | 	CookieExpire   time.Duration | ||||||
|  | 	CookieHTTPOnly bool | ||||||
|  | 	CookieName     string | ||||||
|  | 	CookiePath     string | ||||||
|  | 	CookieSecret   string | ||||||
|  | 	CookieSecure   bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Save takes a sessions.SessionState and stores the information from it
 | ||||||
|  | // within Cookies set on the HTTP response writer
 | ||||||
|  | func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||||
|  | 	value, err := utils.CookieForSession(ss, s.CookieCipher) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	s.setSessionCookie(rw, req, value) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Load reads sessions.SessionState information from Cookies within the
 | ||||||
|  | // HTTP request object
 | ||||||
|  | func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { | ||||||
|  | 	c, err := loadCookie(req, s.CookieName) | ||||||
|  | 	if err != nil { | ||||||
|  | 		// always http.ErrNoCookie
 | ||||||
|  | 		return nil, fmt.Errorf("Cookie %q not present", s.CookieName) | ||||||
|  | 	} | ||||||
|  | 	val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, errors.New("Cookie Signature not valid") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	session, err := utils.SessionFromCookie(val, s.CookieCipher) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return session, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Clear clears any saved session information by writing a cookie to
 | ||||||
|  | // clear the session
 | ||||||
|  | func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||||
|  | 	var cookies []*http.Cookie | ||||||
|  | 
 | ||||||
|  | 	// matches CookieName, CookieName_<number>
 | ||||||
|  | 	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName)) | ||||||
|  | 
 | ||||||
|  | 	for _, c := range req.Cookies() { | ||||||
|  | 		if cookieNameRegex.MatchString(c.Name) { | ||||||
|  | 			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) | ||||||
|  | 
 | ||||||
|  | 			http.SetCookie(rw, clearCookie) | ||||||
|  | 			cookies = append(cookies, clearCookie) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // setSessionCookie adds the user's session cookie to the response
 | ||||||
|  | func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||||
|  | 	for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { | ||||||
|  | 		http.SetCookie(rw, c) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // makeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||||
|  | // authentication details
 | ||||||
|  | func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { | ||||||
|  | 	if value != "" { | ||||||
|  | 		value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) | ||||||
|  | 	} | ||||||
|  | 	c := s.makeCookie(req, s.CookieName, value, expiration) | ||||||
|  | 	if len(c.Value) > 4096-len(s.CookieName) { | ||||||
|  | 		return splitCookie(c) | ||||||
|  | 	} | ||||||
|  | 	return []*http.Cookie{c} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { | ||||||
|  | 	return cookies.MakeCookie( | ||||||
|  | 		req, | ||||||
|  | 		name, | ||||||
|  | 		value, | ||||||
|  | 		s.CookiePath, | ||||||
|  | 		s.CookieDomain, | ||||||
|  | 		s.CookieHTTPOnly, | ||||||
|  | 		s.CookieSecure, | ||||||
|  | 		expiration, | ||||||
|  | 		time.Now(), | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewCookieSessionStore initialises a new instance of the SessionStore from
 | ||||||
|  | // the configuration given
 | ||||||
|  | func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||||
|  | 	var cipher *cookie.Cipher | ||||||
|  | 	if len(cookieOpts.CookieSecret) > 0 { | ||||||
|  | 		var err error | ||||||
|  | 		cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("unable to create cipher: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return &SessionStore{ | ||||||
|  | 		CookieCipher:   cipher, | ||||||
|  | 		CookieDomain:   cookieOpts.CookieDomain, | ||||||
|  | 		CookieExpire:   cookieOpts.CookieExpire, | ||||||
|  | 		CookieHTTPOnly: cookieOpts.CookieHTTPOnly, | ||||||
|  | 		CookieName:     cookieOpts.CookieName, | ||||||
|  | 		CookiePath:     cookieOpts.CookiePath, | ||||||
|  | 		CookieSecret:   cookieOpts.CookieSecret, | ||||||
|  | 		CookieSecure:   cookieOpts.CookieSecure, | ||||||
|  | 	}, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // splitCookie reads the full cookie generated to store the session and splits
 | ||||||
|  | // it into a slice of cookies which fit within the 4kb cookie limit indexing
 | ||||||
|  | // the cookies from 0
 | ||||||
|  | func splitCookie(c *http.Cookie) []*http.Cookie { | ||||||
|  | 	if len(c.Value) < maxCookieLength { | ||||||
|  | 		return []*http.Cookie{c} | ||||||
|  | 	} | ||||||
|  | 	cookies := []*http.Cookie{} | ||||||
|  | 	valueBytes := []byte(c.Value) | ||||||
|  | 	count := 0 | ||||||
|  | 	for len(valueBytes) > 0 { | ||||||
|  | 		new := copyCookie(c) | ||||||
|  | 		new.Name = fmt.Sprintf("%s_%d", c.Name, count) | ||||||
|  | 		count++ | ||||||
|  | 		if len(valueBytes) < maxCookieLength { | ||||||
|  | 			new.Value = string(valueBytes) | ||||||
|  | 			valueBytes = []byte{} | ||||||
|  | 		} else { | ||||||
|  | 			newValue := valueBytes[:maxCookieLength] | ||||||
|  | 			valueBytes = valueBytes[maxCookieLength:] | ||||||
|  | 			new.Value = string(newValue) | ||||||
|  | 		} | ||||||
|  | 		cookies = append(cookies, new) | ||||||
|  | 	} | ||||||
|  | 	return cookies | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // loadCookie retreieves the sessions state cookie from the http request.
 | ||||||
|  | // If a single cookie is present this will be returned, otherwise it attempts
 | ||||||
|  | // to reconstruct a cookie split up by splitCookie
 | ||||||
|  | func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { | ||||||
|  | 	c, err := req.Cookie(cookieName) | ||||||
|  | 	if err == nil { | ||||||
|  | 		return c, nil | ||||||
|  | 	} | ||||||
|  | 	cookies := []*http.Cookie{} | ||||||
|  | 	err = nil | ||||||
|  | 	count := 0 | ||||||
|  | 	for err == nil { | ||||||
|  | 		var c *http.Cookie | ||||||
|  | 		c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) | ||||||
|  | 		if err == nil { | ||||||
|  | 			cookies = append(cookies, c) | ||||||
|  | 			count++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if len(cookies) == 0 { | ||||||
|  | 		return nil, fmt.Errorf("Could not find cookie %s", cookieName) | ||||||
|  | 	} | ||||||
|  | 	return joinCookies(cookies) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // joinCookies takes a slice of cookies from the request and reconstructs the
 | ||||||
|  | // full session cookie
 | ||||||
|  | func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { | ||||||
|  | 	if len(cookies) == 0 { | ||||||
|  | 		return nil, fmt.Errorf("list of cookies must be > 0") | ||||||
|  | 	} | ||||||
|  | 	if len(cookies) == 1 { | ||||||
|  | 		return cookies[0], nil | ||||||
|  | 	} | ||||||
|  | 	c := copyCookie(cookies[0]) | ||||||
|  | 	for i := 1; i < len(cookies); i++ { | ||||||
|  | 		c.Value += cookies[i].Value | ||||||
|  | 	} | ||||||
|  | 	c.Name = strings.TrimRight(c.Name, "_0") | ||||||
|  | 	return c, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func copyCookie(c *http.Cookie) *http.Cookie { | ||||||
|  | 	return &http.Cookie{ | ||||||
|  | 		Name:       c.Name, | ||||||
|  | 		Value:      c.Value, | ||||||
|  | 		Path:       c.Path, | ||||||
|  | 		Domain:     c.Domain, | ||||||
|  | 		Expires:    c.Expires, | ||||||
|  | 		RawExpires: c.RawExpires, | ||||||
|  | 		MaxAge:     c.MaxAge, | ||||||
|  | 		Secure:     c.Secure, | ||||||
|  | 		HttpOnly:   c.HttpOnly, | ||||||
|  | 		Raw:        c.Raw, | ||||||
|  | 		Unparsed:   c.Unparsed, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,19 @@ | ||||||
|  | package sessions | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // NewSessionStore creates a SessionStore from the provided configuration
 | ||||||
|  | func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { | ||||||
|  | 	switch opts.Type { | ||||||
|  | 	case options.CookieSessionStoreType: | ||||||
|  | 		return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) | ||||||
|  | 	default: | ||||||
|  | 		return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,254 @@ | ||||||
|  | package sessions_test | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | 
 | ||||||
|  | 	. "github.com/onsi/ginkgo" | ||||||
|  | 	. "github.com/onsi/gomega" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/options" | ||||||
|  | 	sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/cookies" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/sessions/cookie" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestSessionStore(t *testing.T) { | ||||||
|  | 	RegisterFailHandler(Fail) | ||||||
|  | 	RunSpecs(t, "SessionStore") | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var _ = Describe("NewSessionStore", func() { | ||||||
|  | 	var opts *options.SessionOptions | ||||||
|  | 	var cookieOpts *options.CookieOptions | ||||||
|  | 
 | ||||||
|  | 	var request *http.Request | ||||||
|  | 	var response *httptest.ResponseRecorder | ||||||
|  | 	var session *sessionsapi.SessionState | ||||||
|  | 	var ss sessionsapi.SessionStore | ||||||
|  | 
 | ||||||
|  | 	CheckCookieOptions := func() { | ||||||
|  | 		Context("the cookies returned", func() { | ||||||
|  | 			var cookies []*http.Cookie | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				cookies = response.Result().Cookies() | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("have the correct name set", func() { | ||||||
|  | 				if len(cookies) == 1 { | ||||||
|  | 					Expect(cookies[0].Name).To(Equal(cookieOpts.CookieName)) | ||||||
|  | 				} else { | ||||||
|  | 					for _, cookie := range cookies { | ||||||
|  | 						Expect(cookie.Name).To(ContainSubstring(cookieOpts.CookieName)) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("have the correct path set", func() { | ||||||
|  | 				for _, cookie := range cookies { | ||||||
|  | 					Expect(cookie.Path).To(Equal(cookieOpts.CookiePath)) | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("have the correct domain set", func() { | ||||||
|  | 				for _, cookie := range cookies { | ||||||
|  | 					Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain)) | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("have the correct HTTPOnly set", func() { | ||||||
|  | 				for _, cookie := range cookies { | ||||||
|  | 					Expect(cookie.HttpOnly).To(Equal(cookieOpts.CookieHTTPOnly)) | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("have the correct secure set", func() { | ||||||
|  | 				for _, cookie := range cookies { | ||||||
|  | 					Expect(cookie.Secure).To(Equal(cookieOpts.CookieSecure)) | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	SessionStoreInterfaceTests := func() { | ||||||
|  | 		Context("when Save is called", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				err := ss.Save(response, request, session) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("sets a `set-cookie` header in the response", func() { | ||||||
|  | 				Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			CheckCookieOptions() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("when Clear is called", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				cookie := cookies.MakeCookie(request, | ||||||
|  | 					cookieOpts.CookieName, | ||||||
|  | 					"foo", | ||||||
|  | 					cookieOpts.CookiePath, | ||||||
|  | 					cookieOpts.CookieDomain, | ||||||
|  | 					cookieOpts.CookieHTTPOnly, | ||||||
|  | 					cookieOpts.CookieSecure, | ||||||
|  | 					cookieOpts.CookieExpire, | ||||||
|  | 					time.Now(), | ||||||
|  | 				) | ||||||
|  | 				request.AddCookie(cookie) | ||||||
|  | 				err := ss.Clear(response, request) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("sets a `set-cookie` header in the response", func() { | ||||||
|  | 				Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			CheckCookieOptions() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("when Load is called", func() { | ||||||
|  | 			var loadedSession *sessionsapi.SessionState | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				req := httptest.NewRequest("GET", "http://example.com/", nil) | ||||||
|  | 				resp := httptest.NewRecorder() | ||||||
|  | 				err := ss.Save(resp, req, session) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				for _, cookie := range resp.Result().Cookies() { | ||||||
|  | 					request.AddCookie(cookie) | ||||||
|  | 				} | ||||||
|  | 				loadedSession, err = ss.Load(request) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("loads a session equal to the original session", func() { | ||||||
|  | 				if cookieOpts.CookieSecret == "" { | ||||||
|  | 					// Only Email and User stored in session when encrypted
 | ||||||
|  | 					Expect(loadedSession.Email).To(Equal(session.Email)) | ||||||
|  | 					Expect(loadedSession.User).To(Equal(session.User)) | ||||||
|  | 				} else { | ||||||
|  | 					// All fields stored in session if encrypted
 | ||||||
|  | 
 | ||||||
|  | 					// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | ||||||
|  | 					l := *loadedSession | ||||||
|  | 					l.ExpiresOn = time.Time{} | ||||||
|  | 					s := *session | ||||||
|  | 					s.ExpiresOn = time.Time{} | ||||||
|  | 					Expect(l).To(Equal(s)) | ||||||
|  | 
 | ||||||
|  | 					// Compare time.Time separately
 | ||||||
|  | 					Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) | ||||||
|  | 				} | ||||||
|  | 			}) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	RunSessionTests := func() { | ||||||
|  | 		Context("with default options", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				var err error | ||||||
|  | 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			SessionStoreInterfaceTests() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with non-default options", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				cookieOpts = &options.CookieOptions{ | ||||||
|  | 					CookieName:     "_cookie_name", | ||||||
|  | 					CookiePath:     "/path", | ||||||
|  | 					CookieExpire:   time.Duration(72) * time.Hour, | ||||||
|  | 					CookieRefresh:  time.Duration(3600), | ||||||
|  | 					CookieSecure:   false, | ||||||
|  | 					CookieHTTPOnly: false, | ||||||
|  | 					CookieDomain:   "example.com", | ||||||
|  | 				} | ||||||
|  | 
 | ||||||
|  | 				var err error | ||||||
|  | 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			SessionStoreInterfaceTests() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("with a cookie-secret set", func() { | ||||||
|  | 			BeforeEach(func() { | ||||||
|  | 				secret := make([]byte, 32) | ||||||
|  | 				_, err := rand.Read(secret) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) | ||||||
|  | 
 | ||||||
|  | 				ss, err = sessions.NewSessionStore(opts, cookieOpts) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			SessionStoreInterfaceTests() | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	BeforeEach(func() { | ||||||
|  | 		ss = nil | ||||||
|  | 		opts = &options.SessionOptions{} | ||||||
|  | 
 | ||||||
|  | 		// Set default options in CookieOptions
 | ||||||
|  | 		cookieOpts = &options.CookieOptions{ | ||||||
|  | 			CookieName:     "_oauth2_proxy", | ||||||
|  | 			CookiePath:     "/", | ||||||
|  | 			CookieExpire:   time.Duration(168) * time.Hour, | ||||||
|  | 			CookieRefresh:  time.Duration(0), | ||||||
|  | 			CookieSecure:   true, | ||||||
|  | 			CookieHTTPOnly: true, | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		session = &sessionsapi.SessionState{ | ||||||
|  | 			AccessToken:  "AccessToken", | ||||||
|  | 			IDToken:      "IDToken", | ||||||
|  | 			ExpiresOn:    time.Now().Add(1 * time.Hour), | ||||||
|  | 			RefreshToken: "RefreshToken", | ||||||
|  | 			Email:        "john.doe@example.com", | ||||||
|  | 			User:         "john.doe", | ||||||
|  | 		} | ||||||
|  | 
 | ||||||
|  | 		request = httptest.NewRequest("GET", "http://example.com/", nil) | ||||||
|  | 		response = httptest.NewRecorder() | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with type 'cookie'", func() { | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			opts.Type = options.CookieSessionStoreType | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("creates a cookie.SessionStore", func() { | ||||||
|  | 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||||
|  | 			Expect(err).NotTo(HaveOccurred()) | ||||||
|  | 			Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{})) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		Context("the cookie.SessionStore", func() { | ||||||
|  | 			RunSessionTests() | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	Context("with an invalid type", func() { | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			opts.Type = "invalid-type" | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		It("returns an error", func() { | ||||||
|  | 			ss, err := sessions.NewSessionStore(opts, cookieOpts) | ||||||
|  | 			Expect(err).To(HaveOccurred()) | ||||||
|  | 			Expect(err.Error()).To(Equal("unknown session store type 'invalid-type'")) | ||||||
|  | 			Expect(ss).To(BeNil()) | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | }) | ||||||
|  | @ -0,0 +1,41 @@ | ||||||
|  | package utils | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"encoding/base64" | ||||||
|  | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // CookieForSession serializes a session state for storage in a cookie
 | ||||||
|  | func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { | ||||||
|  | 	return s.EncodeSessionState(c) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SessionFromCookie deserializes a session from a cookie value
 | ||||||
|  | func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { | ||||||
|  | 	return sessions.DecodeSessionState(v, c) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
 | ||||||
|  | func SecretBytes(secret string) []byte { | ||||||
|  | 	b, err := base64.URLEncoding.DecodeString(addPadding(secret)) | ||||||
|  | 	if err == nil { | ||||||
|  | 		return []byte(addPadding(string(b))) | ||||||
|  | 	} | ||||||
|  | 	return []byte(secret) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func addPadding(secret string) string { | ||||||
|  | 	padding := len(secret) % 4 | ||||||
|  | 	switch padding { | ||||||
|  | 	case 1: | ||||||
|  | 		return secret + "===" | ||||||
|  | 	case 2: | ||||||
|  | 		return secret + "==" | ||||||
|  | 	case 3: | ||||||
|  | 		return secret + "=" | ||||||
|  | 	default: | ||||||
|  | 		return secret | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | @ -9,6 +9,7 @@ import ( | ||||||
| 	"github.com/bitly/go-simplejson" | 	"github.com/bitly/go-simplejson" | ||||||
| 	"github.com/pusher/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // AzureProvider represents an Azure based Identity Provider
 | // AzureProvider represents an Azure based Identity Provider
 | ||||||
|  | @ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 	var email string | 	var email string | ||||||
| 	var err error | 	var err error | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@windows.net", email) | 	assert.Equal(t, "user@windows.net", email) | ||||||
|  | @ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@windows.net", email) | 	assert.Equal(t, "user@windows.net", email) | ||||||
|  | @ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@windows.net", email) | 	assert.Equal(t, "user@windows.net", email) | ||||||
|  | @ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, "type assertion to string failed", err.Error()) | 	assert.Equal(t, "type assertion to string failed", err.Error()) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  | @ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  | @ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testAzureProvider(bURL.Host) | 	p := testAzureProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, "type assertion to string failed", err.Error()) | 	assert.Equal(t, "type assertion to string failed", err.Error()) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // FacebookProvider represents an Facebook based Identity Provider
 | // FacebookProvider represents an Facebook based Identity Provider
 | ||||||
|  | @ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
|  | @ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { | func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||||
| 	return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) | 	return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -11,6 +11,7 @@ import ( | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // GitHubProvider represents an GitHub based Identity Provider
 | // GitHubProvider represents an GitHub based Identity Provider
 | ||||||
|  | @ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 
 | 
 | ||||||
| 	var emails []struct { | 	var emails []struct { | ||||||
| 		Email    string `json:"email"` | 		Email    string `json:"email"` | ||||||
|  | @ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetUserName returns the Account user name
 | // GetUserName returns the Account user name
 | ||||||
| func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { | func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { | ||||||
| 	var user struct { | 	var user struct { | ||||||
| 		Login string `json:"login"` | 		Login string `json:"login"` | ||||||
| 		Email string `json:"email"` | 		Email string `json:"email"` | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitHubProvider(bURL.Host) | 	p := testGitHubProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
|  | @ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitHubProvider(bURL.Host) | 	p := testGitHubProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Empty(t, "", email) | 	assert.Empty(t, "", email) | ||||||
|  | @ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { | ||||||
| 	p := testGitHubProvider(bURL.Host) | 	p := testGitHubProvider(bURL.Host) | ||||||
| 	p.Org = "testorg1" | 	p.Org = "testorg1" | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
|  | @ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
| 	// JSON to fail.
 | 	// JSON to fail.
 | ||||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  | @ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitHubProvider(bURL.Host) | 	p := testGitHubProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  | @ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitHubProvider(bURL.Host) | 	p := testGitHubProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetUserName(session) | 	email, err := p.GetUserName(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "mbland", email) | 	assert.Equal(t, "mbland", email) | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // GitLabProvider represents an GitLab based Identity Provider
 | // GitLabProvider represents an GitLab based Identity Provider
 | ||||||
|  | @ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 
 | 
 | ||||||
| 	req, err := http.NewRequest("GET", | 	req, err := http.NewRequest("GET", | ||||||
| 		p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) | 		p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitLabProvider(bURL.Host) | 	p := testGitLabProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "michael.bland@gsa.gov", email) | 	assert.Equal(t, "michael.bland@gsa.gov", email) | ||||||
|  | @ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
| 	// JSON to fail.
 | 	// JSON to fail.
 | ||||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  | @ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testGitLabProvider(bURL.Host) | 	p := testGitLabProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  |  | ||||||
|  | @ -14,6 +14,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/logger" | 	"github.com/pusher/oauth2_proxy/logger" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| 	"golang.org/x/oauth2/google" | 	"golang.org/x/oauth2/google" | ||||||
| 	admin "google.golang.org/api/admin/directory/v1" | 	admin "google.golang.org/api/admin/directory/v1" | ||||||
|  | @ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		err = errors.New("missing code") | ||||||
| 		return | 		return | ||||||
|  | @ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	s = &SessionState{ | 	s = &sessions.SessionState{ | ||||||
| 		AccessToken:  jsonResponse.AccessToken, | 		AccessToken:  jsonResponse.AccessToken, | ||||||
| 		IDToken:      jsonResponse.IDToken, | 		IDToken:      jsonResponse.IDToken, | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||||
|  | @ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | // RefreshToken to fetch a new ID token if required
 | ||||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct { | ||||||
| 	*ProviderData | 	*ProviderData | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { | func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 	return "", errors.New("not implemented") | 	return "", errors.New("not implemented") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Note that we're testing the internal validateToken() used to implement
 | // Note that we're testing the internal validateToken() used to implement
 | ||||||
| // several Provider's ValidateSessionState() implementations
 | // several Provider's ValidateSessionState() implementations
 | ||||||
| func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { | func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/api" | 	"github.com/pusher/oauth2_proxy/api" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // LinkedInProvider represents an LinkedIn based Identity Provider
 | // LinkedInProvider represents an LinkedIn based Identity Provider
 | ||||||
|  | @ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 	if s.AccessToken == "" { | 	if s.AccessToken == "" { | ||||||
| 		return "", errors.New("missing access token") | 		return "", errors.New("missing access token") | ||||||
| 	} | 	} | ||||||
|  | @ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { | func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||||
| 	return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | 	return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(bURL.Host) | 	p := testLinkedInProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, "user@linkedin.com", email) | 	assert.Equal(t, "user@linkedin.com", email) | ||||||
|  | @ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { | ||||||
| 	// We'll trigger a request failure by using an unexpected access
 | 	// We'll trigger a request failure by using an unexpected access
 | ||||||
| 	// token. Alternatively, we could allow the parsing of the payload as
 | 	// token. Alternatively, we could allow the parsing of the payload as
 | ||||||
| 	// JSON to fail.
 | 	// JSON to fail.
 | ||||||
| 	session := &SessionState{AccessToken: "unexpected_access_token"} | 	session := &sessions.SessionState{AccessToken: "unexpected_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  | @ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { | ||||||
| 	bURL, _ := url.Parse(b.URL) | 	bURL, _ := url.Parse(b.URL) | ||||||
| 	p := testLinkedInProvider(bURL.Host) | 	p := testLinkedInProvider(bURL.Host) | ||||||
| 
 | 
 | ||||||
| 	session := &SessionState{AccessToken: "imaginary_access_token"} | 	session := &sessions.SessionState{AccessToken: "imaginary_access_token"} | ||||||
| 	email, err := p.GetEmailAddress(session) | 	email, err := p.GetEmailAddress(session) | ||||||
| 	assert.NotEqual(t, nil, err) | 	assert.NotEqual(t, nil, err) | ||||||
| 	assert.Equal(t, "", email) | 	assert.Equal(t, "", email) | ||||||
|  |  | ||||||
|  | @ -13,6 +13,7 @@ import ( | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/dgrijalva/jwt-go" | 	"github.com/dgrijalva/jwt-go" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"gopkg.in/square/go-jose.v2" | 	"gopkg.in/square/go-jose.v2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		err = errors.New("missing code") | ||||||
| 		return | 		return | ||||||
|  | @ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Store the data that we found in the session state
 | 	// Store the data that we found in the session state
 | ||||||
| 	s = &SessionState{ | 	s = &sessions.SessionState{ | ||||||
| 		AccessToken: jsonResponse.AccessToken, | 		AccessToken: jsonResponse.AccessToken, | ||||||
| 		IDToken:     jsonResponse.IDToken, | 		IDToken:     jsonResponse.IDToken, | ||||||
| 		ExpiresOn:   time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | 		ExpiresOn:   time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||||
|  |  | ||||||
|  | @ -5,9 +5,9 @@ import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"golang.org/x/oauth2" |  | ||||||
| 
 |  | ||||||
| 	oidc "github.com/coreos/go-oidc" | 	oidc "github.com/coreos/go-oidc" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // OIDCProvider represents an OIDC based Identity Provider
 | // OIDCProvider represents an OIDC based Identity Provider
 | ||||||
|  | @ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
| 		ClientID:     p.ClientID, | 		ClientID:     p.ClientID, | ||||||
|  | @ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | // RefreshToken to fetch a new ID token if required
 | ||||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
|  | @ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { | func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
| 		ClientID:     p.ClientID, | 		ClientID:     p.ClientID, | ||||||
| 		ClientSecret: p.ClientSecret, | 		ClientSecret: p.ClientSecret, | ||||||
|  | @ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { | func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { | ||||||
| 	rawIDToken, ok := token.Extra("id_token").(string) | 	rawIDToken, ok := token.Extra("id_token").(string) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil, fmt.Errorf("token response did not contain an id_token") | 		return nil, fmt.Errorf("token response did not contain an id_token") | ||||||
|  | @ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | ||||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return &SessionState{ | 	return &sessions.SessionState{ | ||||||
| 		AccessToken:  token.AccessToken, | 		AccessToken:  token.AccessToken, | ||||||
| 		IDToken:      rawIDToken, | 		IDToken:      rawIDToken, | ||||||
| 		RefreshToken: token.RefreshToken, | 		RefreshToken: token.RefreshToken, | ||||||
|  | @ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | // ValidateSessionState checks that the session's IDToken is still valid
 | ||||||
| func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { | func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	_, err := p.Verifier.Verify(ctx, s.IDToken) | 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -10,10 +10,11 @@ import ( | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 
 | 
 | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Redeem provides a default implementation of the OAuth2 token redemption process
 | // Redeem provides a default implementation of the OAuth2 token redemption process
 | ||||||
| func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { | func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		err = errors.New("missing code") | 		err = errors.New("missing code") | ||||||
| 		return | 		return | ||||||
|  | @ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er | ||||||
| 	} | 	} | ||||||
| 	err = json.Unmarshal(body, &jsonResponse) | 	err = json.Unmarshal(body, &jsonResponse) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		s = &SessionState{ | 		s = &sessions.SessionState{ | ||||||
| 			AccessToken: jsonResponse.AccessToken, | 			AccessToken: jsonResponse.AccessToken, | ||||||
| 		} | 		} | ||||||
| 		return | 		return | ||||||
|  | @ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	if a := v.Get("access_token"); a != "" { | 	if a := v.Get("access_token"); a != "" { | ||||||
| 		s = &SessionState{AccessToken: a} | 		s = &sessions.SessionState{AccessToken: a} | ||||||
| 	} else { | 	} else { | ||||||
| 		err = fmt.Errorf("no access token found %s", body) | 		err = fmt.Errorf("no access token found %s", body) | ||||||
| 	} | 	} | ||||||
|  | @ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CookieForSession serializes a session state for storage in a cookie
 | // CookieForSession serializes a session state for storage in a cookie
 | ||||||
| func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { | func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { | ||||||
| 	return s.EncodeSessionState(c) | 	return s.EncodeSessionState(c) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // SessionFromCookie deserializes a session from a cookie value
 | // SessionFromCookie deserializes a session from a cookie value
 | ||||||
| func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { | func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { | ||||||
| 	return DecodeSessionState(v, c) | 	return sessions.DecodeSessionState(v, c) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetEmailAddress returns the Account email address
 | // GetEmailAddress returns the Account email address
 | ||||||
| func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { | func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) { | ||||||
| 	return "", errors.New("not implemented") | 	return "", errors.New("not implemented") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetUserName returns the Account username
 | // GetUserName returns the Account username
 | ||||||
| func (p *ProviderData) GetUserName(s *SessionState) (string, error) { | func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { | ||||||
| 	return "", errors.New("not implemented") | 	return "", errors.New("not implemented") | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSessionState validates the AccessToken
 | ||||||
| func (p *ProviderData) ValidateSessionState(s *SessionState) bool { | func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { | ||||||
| 	return validateToken(p, s.AccessToken, nil) | 	return validateToken(p, s.AccessToken, nil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded should refresh the user's session if required and
 | // RefreshSessionIfNeeded should refresh the user's session if required and
 | ||||||
| // do nothing if a refresh is not required
 | // do nothing if a refresh is not required
 | ||||||
| func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { | ||||||
| 	return false, nil | 	return false, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,12 +4,13 @@ import ( | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| 	"github.com/stretchr/testify/assert" | 	"github.com/stretchr/testify/assert" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestRefresh(t *testing.T) { | func TestRefresh(t *testing.T) { | ||||||
| 	p := &ProviderData{} | 	p := &ProviderData{} | ||||||
| 	refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ | 	refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{ | ||||||
| 		ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), | 		ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), | ||||||
| 	}) | 	}) | ||||||
| 	assert.Equal(t, false, refreshed) | 	assert.Equal(t, false, refreshed) | ||||||
|  |  | ||||||
|  | @ -2,20 +2,21 @@ package providers | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"github.com/pusher/oauth2_proxy/cookie" | 	"github.com/pusher/oauth2_proxy/cookie" | ||||||
|  | 	"github.com/pusher/oauth2_proxy/pkg/apis/sessions" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // Provider represents an upstream identity provider implementation
 | // Provider represents an upstream identity provider implementation
 | ||||||
| type Provider interface { | type Provider interface { | ||||||
| 	Data() *ProviderData | 	Data() *ProviderData | ||||||
| 	GetEmailAddress(*SessionState) (string, error) | 	GetEmailAddress(*sessions.SessionState) (string, error) | ||||||
| 	GetUserName(*SessionState) (string, error) | 	GetUserName(*sessions.SessionState) (string, error) | ||||||
| 	Redeem(string, string) (*SessionState, error) | 	Redeem(string, string) (*sessions.SessionState, error) | ||||||
| 	ValidateGroup(string) bool | 	ValidateGroup(string) bool | ||||||
| 	ValidateSessionState(*SessionState) bool | 	ValidateSessionState(*sessions.SessionState) bool | ||||||
| 	GetLoginURL(redirectURI, finalRedirect string) string | 	GetLoginURL(redirectURI, finalRedirect string) string | ||||||
| 	RefreshSessionIfNeeded(*SessionState) (bool, error) | 	RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) | ||||||
| 	SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) | 	SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error) | ||||||
| 	CookieForSession(*SessionState, *cookie.Cipher) (string, error) | 	CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // New provides a new Provider based on the configured provider string
 | // New provides a new Provider based on the configured provider string
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue