Merge pull request #688 from oauth2-proxy/session-middlewares
Refactor session loading to make use of middleware pattern
This commit is contained in:
		
						commit
						f141f7cea0
					
				|  | @ -11,6 +11,7 @@ | |||
| 
 | ||||
| ## Changes since v6.0.0 | ||||
| 
 | ||||
| - [#688](https://github.com/oauth2-proxy/oauth2-proxy/pull/688) Refactor session loading to make use of middleware pattern (@JoelSpeed) | ||||
| - [#593](https://github.com/oauth2-proxy/oauth2-proxy/pull/593) Integrate upstream package with OAuth2 Proxy (@JoelSpeed) | ||||
| - [#687](https://github.com/oauth2-proxy/oauth2-proxy/pull/687) Refactor HTPasswd Validator (@JoelSpeed) | ||||
| - [#624](https://github.com/oauth2-proxy/oauth2-proxy/pull/624) Allow stripping authentication headers from whitelisted requests with `--skip-auth-strip-headers` (@NickMeves) | ||||
|  |  | |||
							
								
								
									
										222
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										222
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -15,7 +15,9 @@ import ( | |||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/justinas/alice" | ||||
| 	ipapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/ip" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" | ||||
|  | @ -23,6 +25,7 @@ import ( | |||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/ip" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/upstream" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/providers" | ||||
|  | @ -98,6 +101,8 @@ type OAuthProxy struct { | |||
| 	trustedIPs              *ip.NetSet | ||||
| 	Banner                  string | ||||
| 	Footer                  string | ||||
| 
 | ||||
| 	sessionChain alice.Chain | ||||
| } | ||||
| 
 | ||||
| // NewOAuthProxy creates a new instance of OAuthProxy from the options provided
 | ||||
|  | @ -156,6 +161,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	sessionChain := buildSessionChain(opts, sessionStore, basicAuthValidator) | ||||
| 
 | ||||
| 	return &OAuthProxy{ | ||||
| 		CookieName:     opts.Cookie.Name, | ||||
| 		CSRFCookieName: fmt.Sprintf("%v_%v", opts.Cookie.Name, "csrf"), | ||||
|  | @ -209,9 +216,45 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr | |||
| 
 | ||||
| 		basicAuthValidator:  basicAuthValidator, | ||||
| 		displayHtpasswdForm: basicAuthValidator != nil, | ||||
| 		sessionChain:        sessionChain, | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionStore, validator basic.Validator) alice.Chain { | ||||
| 	chain := alice.New(middleware.NewScope()) | ||||
| 
 | ||||
| 	if opts.SkipJwtBearerTokens { | ||||
| 		sessionLoaders := []middlewareapi.TokenToSessionLoader{} | ||||
| 		if opts.GetOIDCVerifier() != nil { | ||||
| 			sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ | ||||
| 				Verifier:       opts.GetOIDCVerifier(), | ||||
| 				TokenToSession: opts.GetProvider().CreateSessionStateFromBearerToken, | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
| 		for _, verifier := range opts.GetJWTBearerVerifiers() { | ||||
| 			sessionLoaders = append(sessionLoaders, middlewareapi.TokenToSessionLoader{ | ||||
| 				Verifier: verifier, | ||||
| 			}) | ||||
| 		} | ||||
| 
 | ||||
| 		chain = chain.Append(middleware.NewJwtSessionLoader(sessionLoaders)) | ||||
| 	} | ||||
| 
 | ||||
| 	if validator != nil { | ||||
| 		chain = chain.Append(middleware.NewBasicAuthSessionLoader(validator)) | ||||
| 	} | ||||
| 
 | ||||
| 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | ||||
| 		SessionStore:           sessionStore, | ||||
| 		RefreshPeriod:          opts.Cookie.Refresh, | ||||
| 		RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, | ||||
| 		ValidateSessionState:   opts.GetProvider().ValidateSessionState, | ||||
| 	})) | ||||
| 
 | ||||
| 	return chain | ||||
| } | ||||
| 
 | ||||
| // GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
 | ||||
| // redirect clients to once authenticated
 | ||||
| func (p *OAuthProxy) GetRedirectURI(host string) string { | ||||
|  | @ -780,86 +823,20 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | |||
| // Set-Cookie headers may be set on the response as a side-effect of calling this method.
 | ||||
| func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	var session *sessionsapi.SessionState | ||||
| 	var err error | ||||
| 	var saveSession, clearSession, revalidated bool | ||||
| 
 | ||||
| 	if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" { | ||||
| 		session, err = p.GetJwtSession(req) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error retrieving session from token in Authorization header: %s", err) | ||||
| 		} | ||||
| 		if session != nil { | ||||
| 			saveSession = false | ||||
| 		} | ||||
| 	} | ||||
| 	getSession := p.sessionChain.Then(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		session = middleware.GetRequestScope(req).Session | ||||
| 	})) | ||||
| 	getSession.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 	remoteAddr := ip.GetClientString(p.realClientIPParser, req, true) | ||||
| 	if session == nil { | ||||
| 		session, err = p.LoadCookiedSession(req) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error loading cookied session: %s", err) | ||||
| 		} | ||||
| 
 | ||||
| 		if session != nil { | ||||
| 			if session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { | ||||
| 				logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) | ||||
| 				saveSession = true | ||||
| 			} | ||||
| 
 | ||||
| 			if ok, err := p.provider.RefreshSessionIfNeeded(req.Context(), session); err != nil { | ||||
| 				logger.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) | ||||
| 				clearSession = true | ||||
| 				session = nil | ||||
| 			} else if ok { | ||||
| 				saveSession = true | ||||
| 				revalidated = true | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if session != nil && session.IsExpired() { | ||||
| 		logger.Printf("Removing session: token expired %s", session) | ||||
| 		session = nil | ||||
| 		saveSession = false | ||||
| 		clearSession = true | ||||
| 	} | ||||
| 
 | ||||
| 	if saveSession && !revalidated && session != nil && session.AccessToken != "" { | ||||
| 		if !p.provider.ValidateSessionState(req.Context(), session) { | ||||
| 			logger.Printf("Removing session: error validating %s", session) | ||||
| 			saveSession = false | ||||
| 			session = nil | ||||
| 			clearSession = true | ||||
| 		} | ||||
| 		return nil, ErrNeedsLogin | ||||
| 	} | ||||
| 
 | ||||
| 	if session != nil && session.Email != "" && !p.Validator(session.Email) { | ||||
| 		logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session: removing session %s", session) | ||||
| 		session = nil | ||||
| 		saveSession = false | ||||
| 		clearSession = true | ||||
| 	} | ||||
| 
 | ||||
| 	if saveSession && session != nil { | ||||
| 		err = p.SaveSession(rw, req, session) | ||||
| 		if err != nil { | ||||
| 			logger.PrintAuthf(session.Email, req, logger.AuthError, "Save session error %s", err) | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if clearSession { | ||||
| 		// Invalid session, clear it
 | ||||
| 		p.ClearSessionCookie(rw, req) | ||||
| 	} | ||||
| 
 | ||||
| 	if session == nil { | ||||
| 		session, err = p.CheckBasicAuth(req) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error during basic auth validation: %s", err) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	if session == nil { | ||||
| 		return nil, ErrNeedsLogin | ||||
| 	} | ||||
| 
 | ||||
|  | @ -997,36 +974,6 @@ func (p *OAuthProxy) stripAuthHeaders(req *http.Request) { | |||
| 	} | ||||
| } | ||||
| 
 | ||||
| // CheckBasicAuth checks the requests Authorization header for basic auth
 | ||||
| // credentials and authenticates these against the proxies HtpasswdFile
 | ||||
| func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	if p.basicAuthValidator == nil { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	auth := req.Header.Get("Authorization") | ||||
| 	if auth == "" { | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 	s := strings.SplitN(auth, " ", 2) | ||||
| 	if len(s) != 2 || s[0] != "Basic" { | ||||
| 		return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) | ||||
| 	} | ||||
| 	b, err := b64.StdEncoding.DecodeString(s[1]) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	pair := strings.SplitN(string(b), ":", 2) | ||||
| 	if len(pair) != 2 { | ||||
| 		return nil, fmt.Errorf("invalid format %s", b) | ||||
| 	} | ||||
| 	if p.basicAuthValidator.Validate(pair[0], pair[1]) { | ||||
| 		logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||
| 		return &sessionsapi.SessionState{User: pair[0]}, nil | ||||
| 	} | ||||
| 	logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| // isAjax checks if a request is an ajax request
 | ||||
| func isAjax(req *http.Request) bool { | ||||
| 	acceptValues := req.Header.Values("Accept") | ||||
|  | @ -1044,74 +991,3 @@ func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { | |||
| 	rw.Header().Set("Content-Type", applicationJSON) | ||||
| 	rw.WriteHeader(code) | ||||
| } | ||||
| 
 | ||||
| // GetJwtSession loads a session based on a JWT token in the authorization header.
 | ||||
| // (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers)
 | ||||
| func (p *OAuthProxy) GetJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	rawBearerToken, err := p.findBearerToken(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// If we are using an oidc provider, go ahead and try that provider first with its Verifier
 | ||||
| 	// and Bearer Token -> Session converter
 | ||||
| 	if p.mainJwtBearerVerifier != nil { | ||||
| 		bearerToken, err := p.mainJwtBearerVerifier.Verify(req.Context(), rawBearerToken) | ||||
| 		if err == nil { | ||||
| 			return p.provider.CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	// Otherwise, attempt to verify against the extra JWT issuers and use a more generic
 | ||||
| 	// Bearer Token -> Session converter
 | ||||
| 	for _, verifier := range p.extraJwtBearerVerifiers { | ||||
| 		bearerToken, err := verifier.Verify(req.Context(), rawBearerToken) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
| 
 | ||||
| 		return (*providers.ProviderData)(nil).CreateSessionStateFromBearerToken(req.Context(), rawBearerToken, bearerToken) | ||||
| 	} | ||||
| 	return nil, fmt.Errorf("unable to verify jwt token %s", req.Header.Get("Authorization")) | ||||
| } | ||||
| 
 | ||||
| // findBearerToken finds a valid JWT token from the Authorization header of a given request.
 | ||||
| func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) { | ||||
| 	auth := req.Header.Get("Authorization") | ||||
| 	s := strings.SplitN(auth, " ", 2) | ||||
| 	if len(s) != 2 { | ||||
| 		return "", fmt.Errorf("invalid authorization header %s", auth) | ||||
| 	} | ||||
| 	jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`) | ||||
| 	var rawBearerToken string | ||||
| 	if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) { | ||||
| 		rawBearerToken = s[1] | ||||
| 	} else if s[0] == "Basic" { | ||||
| 		// Check if we have a Bearer token masquerading in Basic
 | ||||
| 		b, err := b64.StdEncoding.DecodeString(s[1]) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 		pair := strings.SplitN(string(b), ":", 2) | ||||
| 		if len(pair) != 2 { | ||||
| 			return "", fmt.Errorf("invalid format %s", b) | ||||
| 		} | ||||
| 		user, password := pair[0], pair[1] | ||||
| 
 | ||||
| 		// check user, user+password, or just password for a token
 | ||||
| 		if jwtRegex.MatchString(user) { | ||||
| 			// Support blank passwords or magic `x-oauth-basic` passwords - nothing else
 | ||||
| 			if password == "" || password == "x-oauth-basic" { | ||||
| 				rawBearerToken = user | ||||
| 			} | ||||
| 		} else if jwtRegex.MatchString(password) { | ||||
| 			// support passwords and ignore user
 | ||||
| 			rawBearerToken = password | ||||
| 		} | ||||
| 	} | ||||
| 	if rawBearerToken == "" { | ||||
| 		return "", fmt.Errorf("no valid bearer token found in authorization header") | ||||
| 	} | ||||
| 
 | ||||
| 	return rawBearerToken, nil | ||||
| } | ||||
|  |  | |||
|  | @ -1889,7 +1889,7 @@ func TestGetJwtSession(t *testing.T) { | |||
| 
 | ||||
| 	// Bearer
 | ||||
| 	expires := time.Unix(1912151821, 0) | ||||
| 	session, err := test.proxy.GetJwtSession(test.req) | ||||
| 	session, err := test.proxy.getAuthenticatedSession(test.rw, test.req) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, session.User, "1234567890") | ||||
| 	assert.Equal(t, session.Email, "john@example.com") | ||||
|  | @ -1912,70 +1912,6 @@ func TestGetJwtSession(t *testing.T) { | |||
| 	assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") | ||||
| } | ||||
| 
 | ||||
| func TestFindJwtBearerToken(t *testing.T) { | ||||
| 	p := OAuthProxy{CookieName: "oauth2", CookieDomains: []string{"abc"}} | ||||
| 	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}} | ||||
| 
 | ||||
| 	validToken := "eyJfoobar.eyJfoobar.12345asdf" | ||||
| 	var token string | ||||
| 
 | ||||
| 	// Bearer
 | ||||
| 	getReq.Header = map[string][]string{ | ||||
| 		"Authorization": {fmt.Sprintf("Bearer %s", validToken)}, | ||||
| 	} | ||||
| 
 | ||||
| 	token, err := p.findBearerToken(getReq) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, validToken, token) | ||||
| 
 | ||||
| 	// Basic - no password
 | ||||
| 	getReq.SetBasicAuth(token, "") | ||||
| 	token, err = p.findBearerToken(getReq) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, validToken, token) | ||||
| 
 | ||||
| 	// Basic - sentinel password
 | ||||
| 	getReq.SetBasicAuth(token, "x-oauth-basic") | ||||
| 	token, err = p.findBearerToken(getReq) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, validToken, token) | ||||
| 
 | ||||
| 	// Basic - any username, password matching jwt pattern
 | ||||
| 	getReq.SetBasicAuth("any-username-you-could-wish-for", token) | ||||
| 	token, err = p.findBearerToken(getReq) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Equal(t, validToken, token) | ||||
| 
 | ||||
| 	failures := []string{ | ||||
| 		// Too many parts
 | ||||
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA", | ||||
| 		// Not enough parts
 | ||||
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA", | ||||
| 		// Invalid encrypted key
 | ||||
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA", | ||||
| 		// Invalid IV
 | ||||
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA", | ||||
| 		// Invalid ciphertext
 | ||||
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA", | ||||
| 		// Invalid tag
 | ||||
| 		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////", | ||||
| 		// Invalid header
 | ||||
| 		"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA", | ||||
| 		// Invalid header
 | ||||
| 		"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA", | ||||
| 		// Missing alc/enc params
 | ||||
| 		"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA", | ||||
| 	} | ||||
| 
 | ||||
| 	for _, failure := range failures { | ||||
| 		getReq.Header = map[string][]string{ | ||||
| 			"Authorization": {fmt.Sprintf("Bearer %s", failure)}, | ||||
| 		} | ||||
| 		_, err := p.findBearerToken(getReq) | ||||
| 		assert.Error(t, err) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func Test_prepareNoCache(t *testing.T) { | ||||
| 	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		prepareNoCache(w) | ||||
|  |  | |||
|  | @ -0,0 +1,24 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // RequestScope contains information regarding the request that is being made.
 | ||||
| // The RequestScope is used to pass information between different middlewares
 | ||||
| // within the chain.
 | ||||
| type RequestScope struct { | ||||
| 	// Session details the authenticated users information (if it exists).
 | ||||
| 	Session *sessions.SessionState | ||||
| 
 | ||||
| 	// SaveSession indicates whether the session storage should attempt to save
 | ||||
| 	// the session or not.
 | ||||
| 	SaveSession bool | ||||
| 
 | ||||
| 	// ClearSession indicates whether the user should be logged out or not.
 | ||||
| 	ClearSession bool | ||||
| 
 | ||||
| 	// SessionRevalidated indicates whether the session has been revalidated since
 | ||||
| 	// it was loaded or not.
 | ||||
| 	SessionRevalidated bool | ||||
| } | ||||
|  | @ -0,0 +1,24 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| ) | ||||
| 
 | ||||
| // TokenToSessionFunc takes a rawIDToken and an idToken and converts it into a
 | ||||
| // SessionState.
 | ||||
| type TokenToSessionFunc func(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) | ||||
| 
 | ||||
| // TokenToSessionLoader pairs a token verifier with the correct converter function
 | ||||
| // to convert the ID Token to a SessionState.
 | ||||
| type TokenToSessionLoader struct { | ||||
| 	// Verfier is used to verify that the ID Token was signed by the claimed issuer
 | ||||
| 	// and that the token has not been tampered with.
 | ||||
| 	Verifier *oidc.IDTokenVerifier | ||||
| 
 | ||||
| 	// TokenToSession converts a rawIDToken and an idToken to a SessionState.
 | ||||
| 	// (Optional) If not set a default basic implementation is used.
 | ||||
| 	TokenToSession TokenToSessionFunc | ||||
| } | ||||
|  | @ -0,0 +1,88 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/authentication/basic" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| func NewBasicAuthSessionLoader(validator basic.Validator) alice.Constructor { | ||||
| 	return func(next http.Handler) http.Handler { | ||||
| 		return loadBasicAuthSession(validator, next) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // loadBasicAuthSession attmepts to load a session from basic auth credentials
 | ||||
| // stored in an Authorization header within the request.
 | ||||
| // If no authorization header is found, or the header is invalid, no session
 | ||||
| // will be loaded and the request will be passed to the next handler.
 | ||||
| // If a session was loaded by a previous handler, it will not be replaced.
 | ||||
| func loadBasicAuthSession(validator basic.Validator, next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := GetRequestScope(req) | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		if scope.Session != nil { | ||||
| 			// The session was already loaded, pass to the next handler
 | ||||
| 			next.ServeHTTP(rw, req) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		session, err := getBasicSession(validator, req) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error retrieving session from token in Authorization header: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		// Add the session to the scope if it was found
 | ||||
| 		scope.Session = session | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // getBasicSession attempts to load a basic session from the request.
 | ||||
| // If the credentials in the request exist within the htpasswdMap,
 | ||||
| // a new session will be created.
 | ||||
| func getBasicSession(validator basic.Validator, req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	auth := req.Header.Get("Authorization") | ||||
| 	if auth == "" { | ||||
| 		// No auth header provided, so don't attempt to load a session
 | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	user, password, err := findBasicCredentialsFromHeader(auth) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	if validator.Validate(user, password) { | ||||
| 		logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") | ||||
| 		return &sessionsapi.SessionState{User: user}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| // findBasicCredentialsFromHeader finds basic auth credneitals from the
 | ||||
| // Authorization header of a given request.
 | ||||
| func findBasicCredentialsFromHeader(header string) (string, string, error) { | ||||
| 	tokenType, token, err := splitAuthHeader(header) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if tokenType != "Basic" { | ||||
| 		return "", "", fmt.Errorf("invalid Authorization header: %q", header) | ||||
| 	} | ||||
| 
 | ||||
| 	user, password, err := getBasicAuthCredentials(token) | ||||
| 	if err != nil { | ||||
| 		return "", "", fmt.Errorf("error decoding basic auth credentials: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	return user, password, nil | ||||
| } | ||||
|  | @ -0,0 +1,132 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	adminUser     = "admin" | ||||
| 	adminPassword = "Adm1n1str$t0r" | ||||
| 	user1         = "user1" | ||||
| 	user1Password = "UsErOn3P455" | ||||
| 	user2         = "user2" | ||||
| 	user2Password = "us3r2P455W0Rd!" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Basic Auth Session Suite", func() { | ||||
| 	Context("BasicAuthSessionLoader", func() { | ||||
| 
 | ||||
| 		type basicAuthSessionLoaderTableInput struct { | ||||
| 			authorizationHeader string | ||||
| 			existingSession     *sessionsapi.SessionState | ||||
| 			expectedSession     *sessionsapi.SessionState | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with an authorization header", | ||||
| 			func(in basicAuthSessionLoaderTableInput) { | ||||
| 				scope := &middlewareapi.RequestScope{ | ||||
| 					Session: in.existingSession, | ||||
| 				} | ||||
| 
 | ||||
| 				// Set up the request with the authorization header and a request scope
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				req.Header.Set("Authorization", in.authorizationHeader) | ||||
| 				contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 				req = req.WithContext(contextWithScope) | ||||
| 
 | ||||
| 				rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 				validator := fakeBasicValidator{ | ||||
| 					users: map[string]string{ | ||||
| 						adminUser: adminPassword, | ||||
| 						user1:     user1Password, | ||||
| 						user2:     user2Password, | ||||
| 					}, | ||||
| 				} | ||||
| 
 | ||||
| 				// Create the handler with a next handler that will capture the session
 | ||||
| 				// from the scope
 | ||||
| 				var gotSession *sessionsapi.SessionState | ||||
| 				handler := NewBasicAuthSessionLoader(validator)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 					gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session | ||||
| 				})) | ||||
| 				handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 				Expect(gotSession).To(Equal(in.expectedSession)) | ||||
| 			}, | ||||
| 			Entry("<no value>", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("abcdef", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "abcdef", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("abcdef (with existing session)", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "abcdef", | ||||
| 				existingSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 			}), | ||||
| 			Entry("Bearer <password>", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Bearer %s", adminPassword), | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Basic <password>", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Basic %s", adminPassword), | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(:<password>) (with existing session)", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "Basic OlVzRXJPbjNQNDU1", | ||||
| 				existingSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(user1:<user1Password>)", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "Basic dXNlcjE6VXNFck9uM1A0NTU=", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "user1"}, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(user2:<user1Password>)", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "Basic dXNlcjI6VXNFck9uM1A0NTU=", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(user2:<user2Password>)", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "Basic dXNlcjI6dXMzcjJQNDU1VzBSZCE=", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "user2"}, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(admin:<adminPassword>)", basicAuthSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "Basic YWRtaW46QWRtMW4xc3RyJHQwcg==", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "admin"}, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
| 
 | ||||
| type fakeBasicValidator struct { | ||||
| 	users map[string]string | ||||
| } | ||||
| 
 | ||||
| func (f fakeBasicValidator) Validate(user, password string) bool { | ||||
| 	if f.users == nil { | ||||
| 		return false | ||||
| 	} | ||||
| 	if realPassword, ok := f.users[user]; ok { | ||||
| 		return realPassword == password | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | @ -0,0 +1,168 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"regexp" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| const jwtRegexFormat = `^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$` | ||||
| 
 | ||||
| func NewJwtSessionLoader(sessionLoaders []middlewareapi.TokenToSessionLoader) alice.Constructor { | ||||
| 	for i, loader := range sessionLoaders { | ||||
| 		if loader.TokenToSession == nil { | ||||
| 			sessionLoaders[i] = middlewareapi.TokenToSessionLoader{ | ||||
| 				Verifier:       loader.Verifier, | ||||
| 				TokenToSession: createSessionStateFromBearerToken, | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	js := &jwtSessionLoader{ | ||||
| 		jwtRegex:       regexp.MustCompile(jwtRegexFormat), | ||||
| 		sessionLoaders: sessionLoaders, | ||||
| 	} | ||||
| 	return js.loadSession | ||||
| } | ||||
| 
 | ||||
| // jwtSessionLoader is responsible for loading sessions from JWTs in
 | ||||
| // Authorization headers.
 | ||||
| type jwtSessionLoader struct { | ||||
| 	jwtRegex       *regexp.Regexp | ||||
| 	sessionLoaders []middlewareapi.TokenToSessionLoader | ||||
| } | ||||
| 
 | ||||
| // loadSession attempts to load a session from a JWT stored in an Authorization
 | ||||
| // header within the request.
 | ||||
| // If no authorization header is found, or the header is invalid, no session
 | ||||
| // will be loaded and the request will be passed to the next handler.
 | ||||
| // If a session was loaded by a previous handler, it will not be replaced.
 | ||||
| func (j *jwtSessionLoader) loadSession(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := GetRequestScope(req) | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		if scope.Session != nil { | ||||
| 			// The session was already loaded, pass to the next handler
 | ||||
| 			next.ServeHTTP(rw, req) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		session, err := j.getJwtSession(req) | ||||
| 		if err != nil { | ||||
| 			logger.Printf("Error retrieving session from token in Authorization header: %v", err) | ||||
| 		} | ||||
| 
 | ||||
| 		// Add the session to the scope if it was found
 | ||||
| 		scope.Session = session | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // getJwtSession loads a session based on a JWT token in the authorization header.
 | ||||
| // (see the config options skip-jwt-bearer-tokens and extra-jwt-issuers)
 | ||||
| func (j *jwtSessionLoader) getJwtSession(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	auth := req.Header.Get("Authorization") | ||||
| 	if auth == "" { | ||||
| 		// No auth header provided, so don't attempt to load a session
 | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	rawBearerToken, err := j.findBearerTokenFromHeader(auth) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, loader := range j.sessionLoaders { | ||||
| 		bearerToken, err := loader.Verifier.Verify(req.Context(), rawBearerToken) | ||||
| 		if err == nil { | ||||
| 			// The token was verified, convert it to a session
 | ||||
| 			return loader.TokenToSession(req.Context(), rawBearerToken, bearerToken) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return nil, fmt.Errorf("unable to verify jwt token: %q", req.Header.Get("Authorization")) | ||||
| } | ||||
| 
 | ||||
| // findBearerTokenFromHeader finds a valid JWT token from the Authorization header of a given request.
 | ||||
| func (j *jwtSessionLoader) findBearerTokenFromHeader(header string) (string, error) { | ||||
| 	tokenType, token, err := splitAuthHeader(header) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	if tokenType == "Bearer" && j.jwtRegex.MatchString(token) { | ||||
| 		// Found a JWT as a bearer token
 | ||||
| 		return token, nil | ||||
| 	} | ||||
| 
 | ||||
| 	if tokenType == "Basic" { | ||||
| 		// Check if we have a Bearer token masquerading in Basic
 | ||||
| 		return j.getBasicToken(token) | ||||
| 	} | ||||
| 
 | ||||
| 	return "", fmt.Errorf("no valid bearer token found in authorization header") | ||||
| } | ||||
| 
 | ||||
| // getBasicToken tries to extract a token from the basic value provided.
 | ||||
| func (j *jwtSessionLoader) getBasicToken(token string) (string, error) { | ||||
| 	user, password, err := getBasicAuthCredentials(token) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	// check user, user+password, or just password for a token
 | ||||
| 	if j.jwtRegex.MatchString(user) { | ||||
| 		// Support blank passwords or magic `x-oauth-basic` passwords - nothing else
 | ||||
| 		if password == "" || password == "x-oauth-basic" { | ||||
| 			return user, nil | ||||
| 		} | ||||
| 	} else if j.jwtRegex.MatchString(password) { | ||||
| 		// support passwords and ignore user
 | ||||
| 		return password, nil | ||||
| 	} | ||||
| 
 | ||||
| 	return "", fmt.Errorf("invalid basic auth token found in authorization header") | ||||
| } | ||||
| 
 | ||||
| // createSessionStateFromBearerToken is a default implementation for converting
 | ||||
| // a JWT into a session state.
 | ||||
| func createSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessionsapi.SessionState, error) { | ||||
| 	var claims struct { | ||||
| 		Subject           string `json:"sub"` | ||||
| 		Email             string `json:"email"` | ||||
| 		Verified          *bool  `json:"email_verified"` | ||||
| 		PreferredUsername string `json:"preferred_username"` | ||||
| 	} | ||||
| 
 | ||||
| 	if err := idToken.Claims(&claims); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.Email == "" { | ||||
| 		claims.Email = claims.Subject | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.Verified != nil && !*claims.Verified { | ||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	newSession := &sessionsapi.SessionState{ | ||||
| 		Email:             claims.Email, | ||||
| 		User:              claims.Subject, | ||||
| 		PreferredUsername: claims.PreferredUsername, | ||||
| 		AccessToken:       rawIDToken, | ||||
| 		IDToken:           rawIDToken, | ||||
| 		RefreshToken:      "", | ||||
| 		ExpiresOn:         &idToken.Expiry, | ||||
| 	} | ||||
| 
 | ||||
| 	return newSession, nil | ||||
| } | ||||
|  | @ -0,0 +1,492 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/rsa" | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| type noOpKeySet struct { | ||||
| } | ||||
| 
 | ||||
| func (noOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { | ||||
| 	splitStrings := strings.Split(jwt, ".") | ||||
| 	payloadString := splitStrings[1] | ||||
| 	return base64.RawURLEncoding.DecodeString(payloadString) | ||||
| } | ||||
| 
 | ||||
| var _ = Describe("JWT Session Suite", func() { | ||||
| 	/* token payload: | ||||
| 	{ | ||||
| 	  "sub": "1234567890", | ||||
| 	  "aud": "https://test.myapp.com", | ||||
| 	  "name": "John Doe", | ||||
| 	  "email": "john@example.com", | ||||
| 	  "iss": "https://issuer.example.com", | ||||
| 	  "iat": 1553691215, | ||||
| 	  "exp": 1912151821 | ||||
| 	} | ||||
| 	*/ | ||||
| 	const verifiedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + | ||||
| 		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" + | ||||
| 		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" + | ||||
| 		"E1LCJleHAiOjE5MTIxNTE4MjF9." + | ||||
| 		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" + | ||||
| 		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8" | ||||
| 
 | ||||
| 	const verifiedTokenXOAuthBasicBase64 = `ZXlKaGJHY2lPaUpTVXpJMU5pSXNJblI1Y0NJNklrcFhWQ0o5LmV5SnpkV0lpT2lJeE1qTTBOVFkz | ||||
| T0Rrd0lpd2lZWFZrSWpvaWFIUjBjSE02THk5MFpYTjBMbTE1WVhCd0xtTnZiU0lzSW01aGJXVWlP | ||||
| aUpLYjJodUlFUnZaU0lzSW1WdFlXbHNJam9pYW05b2JrQmxlR0Z0Y0d4bExtTnZiU0lzSW1semN5 | ||||
| STZJbWgwZEhCek9pOHZhWE56ZFdWeUxtVjRZVzF3YkdVdVkyOXRJaXdpYVdGMElqb3hOVFV6Tmpr | ||||
| eE1qRTFMQ0psZUhBaU9qRTVNVEl4TlRFNE1qRjkuckxWeXpPbkVsZFVxX3BOa2ZhLVdpVjhUVkpZ | ||||
| V3laQ2FNMkFtX3VvOEZHZzExekQ3bC1xbXozeDFzZVR2cXBINlkwVHkwMGZtdjZkSm5HbkM4V01u | ||||
| UFhRaW9kUlRmaEJTZU9LWk11MEhrTUQyc2c1MnpsS2tiZkxUTzZpYzVWbmJWZ3dqanJCOGFtX1Rh | ||||
| Nnc3a3lGVWFCNUMxQnNJcnJMTWxka1dFaHluYmI4Ongtb2F1dGgtYmFzaWM=` | ||||
| 
 | ||||
| 	var verifiedSessionExpiry = time.Unix(1912151821, 0) | ||||
| 	var verifiedSession = &sessionsapi.SessionState{ | ||||
| 		AccessToken: verifiedToken, | ||||
| 		IDToken:     verifiedToken, | ||||
| 		Email:       "john@example.com", | ||||
| 		User:        "1234567890", | ||||
| 		ExpiresOn:   &verifiedSessionExpiry, | ||||
| 	} | ||||
| 
 | ||||
| 	// validToken will pass the token regex so can be used to check token fetching
 | ||||
| 	// is valid. It will not pass the OIDC Verifier however.
 | ||||
| 	const validToken = "eyJfoobar.eyJfoobar.12345asdf" | ||||
| 
 | ||||
| 	Context("JwtSessionLoader", func() { | ||||
| 		var verifier *oidc.IDTokenVerifier | ||||
| 		const nonVerifiedToken = validToken | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			keyset := noOpKeySet{} | ||||
| 			verifier = oidc.NewVerifier("https://issuer.example.com", keyset, | ||||
| 				&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) | ||||
| 		}) | ||||
| 
 | ||||
| 		type jwtSessionLoaderTableInput struct { | ||||
| 			authorizationHeader string | ||||
| 			existingSession     *sessionsapi.SessionState | ||||
| 			expectedSession     *sessionsapi.SessionState | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with an authorization header", | ||||
| 			func(in jwtSessionLoaderTableInput) { | ||||
| 				scope := &middlewareapi.RequestScope{ | ||||
| 					Session: in.existingSession, | ||||
| 				} | ||||
| 
 | ||||
| 				// Set up the request with the authorization header and a request scope
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				req.Header.Set("Authorization", in.authorizationHeader) | ||||
| 				contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 				req = req.WithContext(contextWithScope) | ||||
| 
 | ||||
| 				rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 				sessionLoaders := []middlewareapi.TokenToSessionLoader{ | ||||
| 					{ | ||||
| 						Verifier: verifier, | ||||
| 					}, | ||||
| 				} | ||||
| 
 | ||||
| 				// Create the handler with a next handler that will capture the session
 | ||||
| 				// from the scope
 | ||||
| 				var gotSession *sessionsapi.SessionState | ||||
| 				handler := NewJwtSessionLoader(sessionLoaders)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 					gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session | ||||
| 				})) | ||||
| 				handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 				Expect(gotSession).To(Equal(in.expectedSession)) | ||||
| 			}, | ||||
| 			Entry("<no value>", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("abcdef", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "abcdef", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("abcdef  (with existing session)", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "abcdef", | ||||
| 				existingSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 			}), | ||||
| 			Entry("Bearer <verifiedToken>", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     verifiedSession, | ||||
| 			}), | ||||
| 			Entry("Bearer <nonVerifiedToken>", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Bearer %s", nonVerifiedToken), | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Bearer <verifiedToken> (with existing session)", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), | ||||
| 				existingSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 				expectedSession:     &sessionsapi.SessionState{User: "user"}, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(<nonVerifiedToken>:) (No password)", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(<verifiedToken>:x-oauth-basic) (Sentinel password)", jwtSessionLoaderTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Basic %s", verifiedTokenXOAuthBasicBase64), | ||||
| 				existingSession:     nil, | ||||
| 				expectedSession:     verifiedSession, | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("getJWTSession", func() { | ||||
| 		var j *jwtSessionLoader | ||||
| 		const nonVerifiedToken = validToken | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			keyset := noOpKeySet{} | ||||
| 			verifier := oidc.NewVerifier("https://issuer.example.com", keyset, | ||||
| 				&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) | ||||
| 
 | ||||
| 			j = &jwtSessionLoader{ | ||||
| 				jwtRegex: regexp.MustCompile(jwtRegexFormat), | ||||
| 				sessionLoaders: []middlewareapi.TokenToSessionLoader{ | ||||
| 					{ | ||||
| 						Verifier:       verifier, | ||||
| 						TokenToSession: createSessionStateFromBearerToken, | ||||
| 					}, | ||||
| 				}, | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		type getJWTSessionTableInput struct { | ||||
| 			authorizationHeader string | ||||
| 			expectedErr         error | ||||
| 			expectedSession     *sessionsapi.SessionState | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with an authorization header", | ||||
| 			func(in getJWTSessionTableInput) { | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				req.Header.Set("Authorization", in.authorizationHeader) | ||||
| 
 | ||||
| 				session, err := j.getJwtSession(req) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(session).To(Equal(in.expectedSession)) | ||||
| 			}, | ||||
| 			Entry("<no value>", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: "", | ||||
| 				expectedErr:         nil, | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("abcdef", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: "abcdef", | ||||
| 				expectedErr:         errors.New("invalid authorization header: \"abcdef\""), | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Bearer abcdef", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: "Bearer abcdef", | ||||
| 				expectedErr:         errors.New("no valid bearer token found in authorization header"), | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Bearer <nonVerifiedToken>", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Bearer %s", nonVerifiedToken), | ||||
| 				expectedErr:         errors.New("unable to verify jwt token: \"Bearer eyJfoobar.eyJfoobar.12345asdf\""), | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Bearer <verifiedToken>", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Bearer %s", verifiedToken), | ||||
| 				expectedErr:         nil, | ||||
| 				expectedSession:     verifiedSession, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(<nonVerifiedToken>:) (No password)", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", | ||||
| 				expectedErr:         errors.New("unable to verify jwt token: \"Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6\""), | ||||
| 				expectedSession:     nil, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(<verifiedToken>:x-oauth-basic) (Sentinel password)", getJWTSessionTableInput{ | ||||
| 				authorizationHeader: fmt.Sprintf("Basic %s", verifiedTokenXOAuthBasicBase64), | ||||
| 				expectedErr:         nil, | ||||
| 				expectedSession:     verifiedSession, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("findBearerTokenFromHeader", func() { | ||||
| 		var j *jwtSessionLoader | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			j = &jwtSessionLoader{ | ||||
| 				jwtRegex: regexp.MustCompile(jwtRegexFormat), | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		type findBearerTokenFromHeaderTableInput struct { | ||||
| 			header        string | ||||
| 			expectedErr   error | ||||
| 			expectedToken string | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with a header", | ||||
| 			func(in findBearerTokenFromHeaderTableInput) { | ||||
| 				token, err := j.findBearerTokenFromHeader(in.header) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(token).To(Equal(in.expectedToken)) | ||||
| 			}, | ||||
| 			Entry("Bearer", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Bearer", | ||||
| 				expectedErr:   errors.New("invalid authorization header: \"Bearer\""), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Bearer abc def", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Bearer abc def", | ||||
| 				expectedErr:   errors.New("invalid authorization header: \"Bearer abc def\""), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Bearer abcdef", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Bearer abcdef", | ||||
| 				expectedErr:   errors.New("no valid bearer token found in authorization header"), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Bearer <valid-token>", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        fmt.Sprintf("Bearer %s", validToken), | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Basic invalid-base64", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Basic invalid-base64", | ||||
| 				expectedErr:   errors.New("invalid basic auth token: illegal base64 data at input byte 7"), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Basic Base64(<validToken>:) (No password)", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(<validToken>:x-oauth-basic) (Sentinel password)", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Basic ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6eC1vYXV0aC1iYXNpYw==", | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(any-user:<validToken>) (Matching password)", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Basic YW55LXVzZXI6ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY=", | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Basic Base64(any-user:any-password) (No matches)", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Basic YW55LXVzZXI6YW55LXBhc3N3b3Jk", | ||||
| 				expectedErr:   errors.New("invalid basic auth token found in authorization header"), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Basic Base64(any-user any-password) (Invalid format)", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        "Basic YW55LXVzZXIgYW55LXBhc3N3b3Jk", | ||||
| 				expectedErr:   errors.New("invalid format: \"any-user any-password\""), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Something <valid-token>", findBearerTokenFromHeaderTableInput{ | ||||
| 				header:        fmt.Sprintf("Something %s", validToken), | ||||
| 				expectedErr:   errors.New("no valid bearer token found in authorization header"), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 		) | ||||
| 
 | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("getBasicToken", func() { | ||||
| 		var j *jwtSessionLoader | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			j = &jwtSessionLoader{ | ||||
| 				jwtRegex: regexp.MustCompile(jwtRegexFormat), | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		type getBasicTokenTableInput struct { | ||||
| 			token         string | ||||
| 			expectedErr   error | ||||
| 			expectedToken string | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with a token", | ||||
| 			func(in getBasicTokenTableInput) { | ||||
| 				token, err := j.getBasicToken(in.token) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(token).To(Equal(in.expectedToken)) | ||||
| 			}, | ||||
| 			Entry("invalid-base64", getBasicTokenTableInput{ | ||||
| 				token:         "invalid-base64", | ||||
| 				expectedErr:   errors.New("invalid basic auth token: illegal base64 data at input byte 7"), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Base64(<validToken>:) (No password)", getBasicTokenTableInput{ | ||||
| 				token:         "ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6", | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Base64(<validToken>:x-oauth-basic) (Sentinel password)", getBasicTokenTableInput{ | ||||
| 				token:         "ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY6eC1vYXV0aC1iYXNpYw==", | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Base64(any-user:<validToken>) (Matching password)", getBasicTokenTableInput{ | ||||
| 				token:         "YW55LXVzZXI6ZXlKZm9vYmFyLmV5SmZvb2Jhci4xMjM0NWFzZGY=", | ||||
| 				expectedErr:   nil, | ||||
| 				expectedToken: validToken, | ||||
| 			}), | ||||
| 			Entry("Base64(any-user:any-password) (No matches)", getBasicTokenTableInput{ | ||||
| 				token:         "YW55LXVzZXI6YW55LXBhc3N3b3Jk", | ||||
| 				expectedErr:   errors.New("invalid basic auth token found in authorization header"), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 			Entry("Base64(any-user any-password) (Invalid format)", getBasicTokenTableInput{ | ||||
| 				token:         "YW55LXVzZXIgYW55LXBhc3N3b3Jk", | ||||
| 				expectedErr:   errors.New("invalid format: \"any-user any-password\""), | ||||
| 				expectedToken: "", | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("createSessionStateFromBearerToken", func() { | ||||
| 		ctx := context.Background() | ||||
| 		expiresFuture := time.Now().Add(time.Duration(5) * time.Minute) | ||||
| 		verified := true | ||||
| 		notVerified := false | ||||
| 
 | ||||
| 		type idTokenClaims struct { | ||||
| 			Email    string `json:"email,omitempty"` | ||||
| 			Verified *bool  `json:"email_verified,omitempty"` | ||||
| 			jwt.StandardClaims | ||||
| 		} | ||||
| 
 | ||||
| 		type createSessionStateTableInput struct { | ||||
| 			idToken         idTokenClaims | ||||
| 			expectedErr     error | ||||
| 			expectedUser    string | ||||
| 			expectedEmail   string | ||||
| 			expectedExpires *time.Time | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("when creating a session from an IDToken", | ||||
| 			func(in createSessionStateTableInput) { | ||||
| 				verifier := oidc.NewVerifier( | ||||
| 					"https://issuer.example.com", | ||||
| 					noOpKeySet{}, | ||||
| 					&oidc.Config{ClientID: "asdf1234"}, | ||||
| 				) | ||||
| 
 | ||||
| 				key, err := rsa.GenerateKey(rand.Reader, 2048) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, in.idToken).SignedString(key) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				// Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below
 | ||||
| 				idToken, err := verifier.Verify(context.Background(), rawIDToken) | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 				session, err := createSessionStateFromBearerToken(ctx, rawIDToken, idToken) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 					Expect(session).To(BeNil()) | ||||
| 					return | ||||
| 				} | ||||
| 
 | ||||
| 				Expect(err).ToNot(HaveOccurred()) | ||||
| 				Expect(session.AccessToken).To(Equal(rawIDToken)) | ||||
| 				Expect(session.IDToken).To(Equal(rawIDToken)) | ||||
| 				Expect(session.User).To(Equal(in.expectedUser)) | ||||
| 				Expect(session.Email).To(Equal(in.expectedEmail)) | ||||
| 				Expect(session.ExpiresOn.Unix()).To(Equal(in.expectedExpires.Unix())) | ||||
| 				Expect(session.RefreshToken).To(BeEmpty()) | ||||
| 				Expect(session.PreferredUsername).To(BeEmpty()) | ||||
| 			}, | ||||
| 			Entry("with no email", createSessionStateTableInput{ | ||||
| 				idToken: idTokenClaims{ | ||||
| 					StandardClaims: jwt.StandardClaims{ | ||||
| 						Audience:  "asdf1234", | ||||
| 						ExpiresAt: expiresFuture.Unix(), | ||||
| 						Id:        "id-some-id", | ||||
| 						IssuedAt:  time.Now().Unix(), | ||||
| 						Issuer:    "https://issuer.example.com", | ||||
| 						NotBefore: 0, | ||||
| 						Subject:   "123456789", | ||||
| 					}, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectedUser:    "123456789", | ||||
| 				expectedEmail:   "123456789", | ||||
| 				expectedExpires: &expiresFuture, | ||||
| 			}), | ||||
| 			Entry("with a verified email", createSessionStateTableInput{ | ||||
| 				idToken: idTokenClaims{ | ||||
| 					StandardClaims: jwt.StandardClaims{ | ||||
| 						Audience:  "asdf1234", | ||||
| 						ExpiresAt: expiresFuture.Unix(), | ||||
| 						Id:        "id-some-id", | ||||
| 						IssuedAt:  time.Now().Unix(), | ||||
| 						Issuer:    "https://issuer.example.com", | ||||
| 						NotBefore: 0, | ||||
| 						Subject:   "123456789", | ||||
| 					}, | ||||
| 					Email:    "foo@example.com", | ||||
| 					Verified: &verified, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectedUser:    "123456789", | ||||
| 				expectedEmail:   "foo@example.com", | ||||
| 				expectedExpires: &expiresFuture, | ||||
| 			}), | ||||
| 			Entry("with a non-verified email", createSessionStateTableInput{ | ||||
| 				idToken: idTokenClaims{ | ||||
| 					StandardClaims: jwt.StandardClaims{ | ||||
| 						Audience:  "asdf1234", | ||||
| 						ExpiresAt: expiresFuture.Unix(), | ||||
| 						Id:        "id-some-id", | ||||
| 						IssuedAt:  time.Now().Unix(), | ||||
| 						Issuer:    "https://issuer.example.com", | ||||
| 						NotBefore: 0, | ||||
| 						Subject:   "123456789", | ||||
| 					}, | ||||
| 					Email:    "foo@example.com", | ||||
| 					Verified: ¬Verified, | ||||
| 				}, | ||||
| 				expectedErr: errors.New("email in id_token (foo@example.com) isn't verified"), | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -0,0 +1,39 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| ) | ||||
| 
 | ||||
| type scopeKey string | ||||
| 
 | ||||
| // requestScopeKey uses a typed string to reduce likelihood of clasing
 | ||||
| // with other context keys
 | ||||
| const requestScopeKey scopeKey = "request-scope" | ||||
| 
 | ||||
| func NewScope() alice.Constructor { | ||||
| 	return addScope | ||||
| } | ||||
| 
 | ||||
| // addScope injects a new request scope into the request context.
 | ||||
| func addScope(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := &middlewareapi.RequestScope{} | ||||
| 		contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 		requestWithScope := req.WithContext(contextWithScope) | ||||
| 		next.ServeHTTP(rw, requestWithScope) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // GetRequestScope returns the current request scope from the given request
 | ||||
| func GetRequestScope(req *http.Request) *middlewareapi.RequestScope { | ||||
| 	scope := req.Context().Value(requestScopeKey) | ||||
| 	if scope == nil { | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	return scope.(*middlewareapi.RequestScope) | ||||
| } | ||||
|  | @ -0,0 +1,94 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Scope Suite", func() { | ||||
| 	Context("NewScope", func() { | ||||
| 		var request, nextRequest *http.Request | ||||
| 		var rw http.ResponseWriter | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			var err error | ||||
| 			request, err = http.NewRequest("", "http://127.0.0.1/", nil) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 
 | ||||
| 			rw = httptest.NewRecorder() | ||||
| 
 | ||||
| 			handler := NewScope()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 				nextRequest = r | ||||
| 				w.WriteHeader(200) | ||||
| 			})) | ||||
| 			handler.ServeHTTP(rw, request) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("does not add a scope to the original request", func() { | ||||
| 			Expect(request.Context().Value(requestScopeKey)).To(BeNil()) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("cannot load a scope from the original request using GetRequestScope", func() { | ||||
| 			Expect(GetRequestScope(request)).To(BeNil()) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("adds a scope to the request for the next handler", func() { | ||||
| 			Expect(nextRequest.Context().Value(requestScopeKey)).ToNot(BeNil()) | ||||
| 		}) | ||||
| 
 | ||||
| 		It("can load a scope from the next handler's request using GetRequestScope", func() { | ||||
| 			Expect(GetRequestScope(nextRequest)).ToNot(BeNil()) | ||||
| 		}) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("GetRequestScope", func() { | ||||
| 		var request *http.Request | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			var err error | ||||
| 			request, err = http.NewRequest("", "http://127.0.0.1/", nil) | ||||
| 			Expect(err).ToNot(HaveOccurred()) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with a scope", func() { | ||||
| 			var scope *middlewareapi.RequestScope | ||||
| 
 | ||||
| 			BeforeEach(func() { | ||||
| 				scope = &middlewareapi.RequestScope{} | ||||
| 				contextWithScope := context.WithValue(request.Context(), requestScopeKey, scope) | ||||
| 				request = request.WithContext(contextWithScope) | ||||
| 			}) | ||||
| 
 | ||||
| 			It("returns the scope", func() { | ||||
| 				s := GetRequestScope(request) | ||||
| 				Expect(s).ToNot(BeNil()) | ||||
| 				Expect(s).To(Equal(scope)) | ||||
| 			}) | ||||
| 
 | ||||
| 			Context("if the scope is then modified", func() { | ||||
| 				BeforeEach(func() { | ||||
| 					Expect(scope.SaveSession).To(BeFalse()) | ||||
| 					scope.SaveSession = true | ||||
| 				}) | ||||
| 
 | ||||
| 				It("returns the updated session", func() { | ||||
| 					s := GetRequestScope(request) | ||||
| 					Expect(s).ToNot(BeNil()) | ||||
| 					Expect(s).To(Equal(scope)) | ||||
| 					Expect(s.SaveSession).To(BeTrue()) | ||||
| 				}) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("without a scope", func() { | ||||
| 			It("returns nil", func() { | ||||
| 				Expect(GetRequestScope(request)).To(BeNil()) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -0,0 +1,33 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
| 
 | ||||
| // splitAuthHeader takes the auth header value and splits it into the token type
 | ||||
| // and the token value.
 | ||||
| func splitAuthHeader(header string) (string, string, error) { | ||||
| 	s := strings.Split(header, " ") | ||||
| 	if len(s) != 2 { | ||||
| 		return "", "", fmt.Errorf("invalid authorization header: %q", header) | ||||
| 	} | ||||
| 	return s[0], s[1], nil | ||||
| } | ||||
| 
 | ||||
| // getBasicAuthCredentials decodes a basic auth token and extracts the user
 | ||||
| // and password pair.
 | ||||
| func getBasicAuthCredentials(token string) (string, string, error) { | ||||
| 	b, err := base64.StdEncoding.DecodeString(token) | ||||
| 	if err != nil { | ||||
| 		return "", "", fmt.Errorf("invalid basic auth token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	pair := strings.SplitN(string(b), ":", 2) | ||||
| 	if len(pair) != 2 { | ||||
| 		return "", "", fmt.Errorf("invalid format: %q", b) | ||||
| 	} | ||||
| 	// user, password
 | ||||
| 	return pair[0], pair[1], nil | ||||
| } | ||||
|  | @ -0,0 +1,103 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 
 | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Session utilities suite", func() { | ||||
| 	Context("splitAuthHeader", func() { | ||||
| 		type splitAuthTableInput struct { | ||||
| 			header             string | ||||
| 			expectedErr        error | ||||
| 			expectedTokenType  string | ||||
| 			expectedTokenValue string | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("with a header value", | ||||
| 			func(in splitAuthTableInput) { | ||||
| 				tt, tv, err := splitAuthHeader(in.header) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(tt).To(Equal(in.expectedTokenType)) | ||||
| 				Expect(tv).To(Equal(in.expectedTokenValue)) | ||||
| 			}, | ||||
| 			Entry("Bearer abcdef", splitAuthTableInput{ | ||||
| 				header:             "Bearer abcdef", | ||||
| 				expectedErr:        nil, | ||||
| 				expectedTokenType:  "Bearer", | ||||
| 				expectedTokenValue: "abcdef", | ||||
| 			}), | ||||
| 			Entry("Bearer", splitAuthTableInput{ | ||||
| 				header:             "Bearer", | ||||
| 				expectedErr:        errors.New("invalid authorization header: \"Bearer\""), | ||||
| 				expectedTokenType:  "", | ||||
| 				expectedTokenValue: "", | ||||
| 			}), | ||||
| 			Entry("Bearer abc def", splitAuthTableInput{ | ||||
| 				header:             "Bearer abc def", | ||||
| 				expectedErr:        errors.New("invalid authorization header: \"Bearer abc def\""), | ||||
| 				expectedTokenType:  "", | ||||
| 				expectedTokenValue: "", | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("getBasicAuthCredentials", func() { | ||||
| 		type getBasicAuthCredentialsTableInput struct { | ||||
| 			token            string | ||||
| 			expectedErr      error | ||||
| 			expectedUser     string | ||||
| 			expectedPassword string | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("from token", | ||||
| 			func(in getBasicAuthCredentialsTableInput) { | ||||
| 				user, password, err := getBasicAuthCredentials(in.token) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(user).To(Equal(in.expectedUser)) | ||||
| 				Expect(password).To(Equal(in.expectedPassword)) | ||||
| 			}, | ||||
| 			Entry("<no value>", getBasicAuthCredentialsTableInput{ | ||||
| 				token:            "", | ||||
| 				expectedErr:      errors.New("invalid format: \"\""), | ||||
| 				expectedUser:     "", | ||||
| 				expectedPassword: "", | ||||
| 			}), | ||||
| 			Entry("invalid-base64", getBasicAuthCredentialsTableInput{ | ||||
| 				token:            "invalid-base64", | ||||
| 				expectedErr:      errors.New("invalid basic auth token: illegal base64 data at input byte 7"), | ||||
| 				expectedUser:     "", | ||||
| 				expectedPassword: "", | ||||
| 			}), | ||||
| 			Entry("Base64(some-user:some-password)", getBasicAuthCredentialsTableInput{ | ||||
| 				token:            "c29tZS11c2VyOnNvbWUtcGFzc3dvcmQ=", | ||||
| 				expectedErr:      nil, | ||||
| 				expectedUser:     "some-user", | ||||
| 				expectedPassword: "some-password", | ||||
| 			}), | ||||
| 			Entry("Base64(no-password:)", getBasicAuthCredentialsTableInput{ | ||||
| 				token:            "bm8tcGFzc3dvcmQ6", | ||||
| 				expectedErr:      nil, | ||||
| 				expectedUser:     "no-password", | ||||
| 				expectedPassword: "", | ||||
| 			}), | ||||
| 			Entry("Base64(:no-user)", getBasicAuthCredentialsTableInput{ | ||||
| 				token:            "Om5vLXVzZXI=", | ||||
| 				expectedErr:      nil, | ||||
| 				expectedUser:     "", | ||||
| 				expectedPassword: "no-user", | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| }) | ||||
|  | @ -0,0 +1,165 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/justinas/alice" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/logger" | ||||
| ) | ||||
| 
 | ||||
| // StoredSessionLoaderOptions cotnains all of the requirements to construct
 | ||||
| // a stored session loader.
 | ||||
| // All options must be provided.
 | ||||
| type StoredSessionLoaderOptions struct { | ||||
| 	// Session storage basckend
 | ||||
| 	SessionStore sessionsapi.SessionStore | ||||
| 
 | ||||
| 	// How often should sessions be refreshed
 | ||||
| 	RefreshPeriod time.Duration | ||||
| 
 | ||||
| 	// Provider based sesssion refreshing
 | ||||
| 	RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 
 | ||||
| 	// Provider based session validation.
 | ||||
| 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | ||||
| 	// refresh it, we must re-validate using this validation.
 | ||||
| 	ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool | ||||
| } | ||||
| 
 | ||||
| // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | ||||
| // sessions from the session store.
 | ||||
| // If no session is found, the request will be passed to the nex handler.
 | ||||
| // If a session was loader by a previous handler, it will not be replaced.
 | ||||
| func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor { | ||||
| 	ss := &storedSessionLoader{ | ||||
| 		store:                              opts.SessionStore, | ||||
| 		refreshPeriod:                      opts.RefreshPeriod, | ||||
| 		refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, | ||||
| 		validateSessionState:               opts.ValidateSessionState, | ||||
| 	} | ||||
| 	return ss.loadSession | ||||
| } | ||||
| 
 | ||||
| // storedSessionLoader is responsible for loading sessions from cookie
 | ||||
| // identified sessions in the session store.
 | ||||
| type storedSessionLoader struct { | ||||
| 	store                              sessionsapi.SessionStore | ||||
| 	refreshPeriod                      time.Duration | ||||
| 	refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 	validateSessionState               func(context.Context, *sessionsapi.SessionState) bool | ||||
| } | ||||
| 
 | ||||
| // loadSession attempts to load a session as identified by the request cookies.
 | ||||
| // If no session is found, the request will be passed to the nex handler.
 | ||||
| // If a session was loader by a previous handler, it will not be replaced.
 | ||||
| func (s *storedSessionLoader) loadSession(next http.Handler) http.Handler { | ||||
| 	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||||
| 		scope := GetRequestScope(req) | ||||
| 		// If scope is nil, this will panic.
 | ||||
| 		// A scope should always be injected before this handler is called.
 | ||||
| 		if scope.Session != nil { | ||||
| 			// The session was already loaded, pass to the next handler
 | ||||
| 			next.ServeHTTP(rw, req) | ||||
| 			return | ||||
| 		} | ||||
| 
 | ||||
| 		session, err := s.getValidatedSession(rw, req) | ||||
| 		if err != nil { | ||||
| 			// In the case when there was an error loading the session,
 | ||||
| 			// we should clear the session
 | ||||
| 			logger.Printf("Error loading cookied session: %v, removing session", err) | ||||
| 			s.store.Clear(rw, req) | ||||
| 		} | ||||
| 
 | ||||
| 		// Add the session to the scope if it was found
 | ||||
| 		scope.Session = session | ||||
| 		next.ServeHTTP(rw, req) | ||||
| 	}) | ||||
| } | ||||
| 
 | ||||
| // getValidatedSession is responsible for loading a session and making sure
 | ||||
| // that is is valid.
 | ||||
| func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	session, err := s.store.Load(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if session == nil { | ||||
| 		// No session was found in the storage, nothing more to do
 | ||||
| 		return nil, nil | ||||
| 	} | ||||
| 
 | ||||
| 	err = s.refreshSessionIfNeeded(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("error refreshing access token for session (%s): %v", session, err) | ||||
| 	} | ||||
| 
 | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // refreshSessionIfNeeded will attempt to refresh a session if the session
 | ||||
| // is older than the refresh period.
 | ||||
| // It is assumed that if the provider refreshes the session, the session is now
 | ||||
| // valid.
 | ||||
| // If the session requires refreshing but the provider does not refresh it,
 | ||||
| // we must validate the session to ensure that the returned session is still
 | ||||
| // valid.
 | ||||
| func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||
| 	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { | ||||
| 		// Refresh is disabled or the session is not old enough, do nothing
 | ||||
| 		return nil | ||||
| 	} | ||||
| 
 | ||||
| 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | ||||
| 	refreshed, err := s.refreshSessionWithProvider(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	if !refreshed { | ||||
| 		// Session wasn't refreshed, so make sure it's still valid
 | ||||
| 		return s.validateSession(req.Context(), session) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // refreshSessionWithProvider attempts to refresh the sessinon with the provider
 | ||||
| // and will save the session if it was updated.
 | ||||
| func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { | ||||
| 	refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) | ||||
| 	if err != nil { | ||||
| 		return false, fmt.Errorf("error refreshing access token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if !refreshed { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	// Because the session was refreshed, make sure to save it
 | ||||
| 	err = s.store.Save(rw, req, session) | ||||
| 	if err != nil { | ||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||
| 		return false, fmt.Errorf("error saving session: %v", err) | ||||
| 	} | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| // validateSession checks whether the session has expired and performs
 | ||||
| // provider validation on the session.
 | ||||
| // An error implies the session is not longer valid.
 | ||||
| func (s *storedSessionLoader) validateSession(ctx context.Context, session *sessionsapi.SessionState) error { | ||||
| 	if session.IsExpired() { | ||||
| 		return errors.New("session is expired") | ||||
| 	} | ||||
| 
 | ||||
| 	if !s.validateSessionState(ctx, session) { | ||||
| 		return errors.New("session is invalid") | ||||
| 	} | ||||
| 
 | ||||
| 	return nil | ||||
| } | ||||
|  | @ -0,0 +1,524 @@ | |||
| package middleware | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"time" | ||||
| 
 | ||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/middleware" | ||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	. "github.com/onsi/ginkgo" | ||||
| 	. "github.com/onsi/ginkgo/extensions/table" | ||||
| 	. "github.com/onsi/gomega" | ||||
| ) | ||||
| 
 | ||||
| var _ = Describe("Stored Session Suite", func() { | ||||
| 	const ( | ||||
| 		refresh   = "Refresh" | ||||
| 		noRefresh = "NoRefresh" | ||||
| 	) | ||||
| 
 | ||||
| 	var ctx = context.Background() | ||||
| 
 | ||||
| 	Context("StoredSessionLoader", func() { | ||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | ||||
| 		createdFuture := time.Now().Add(5 * time.Minute) | ||||
| 
 | ||||
| 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 			switch ss.RefreshToken { | ||||
| 			case refresh: | ||||
| 				ss.RefreshToken = "Refreshed" | ||||
| 				return true, nil | ||||
| 			case noRefresh: | ||||
| 				return false, nil | ||||
| 			default: | ||||
| 				return false, errors.New("error refreshing session") | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		var defaultValidateFunc = func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||
| 			return ss.AccessToken != "Invalid" | ||||
| 		} | ||||
| 
 | ||||
| 		var defaultSessionStore = &fakeSessionStore{ | ||||
| 			LoadFunc: func(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 				switch req.Header.Get("Cookie") { | ||||
| 				case "_oauth2_proxy=NoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=InvalidNoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						AccessToken:  "Invalid", | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=ExpiredNoRefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: noRefresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdPast, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=RefreshSession": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: refresh, | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=RefreshError": | ||||
| 					return &sessionsapi.SessionState{ | ||||
| 						RefreshToken: "RefreshError", | ||||
| 						CreatedAt:    &createdPast, | ||||
| 						ExpiresOn:    &createdFuture, | ||||
| 					}, nil | ||||
| 				case "_oauth2_proxy=NonExistent": | ||||
| 					return nil, fmt.Errorf("invalid cookie") | ||||
| 				default: | ||||
| 					return nil, nil | ||||
| 				} | ||||
| 			}, | ||||
| 		} | ||||
| 
 | ||||
| 		type storedSessionLoaderTableInput struct { | ||||
| 			requestHeaders  http.Header | ||||
| 			existingSession *sessionsapi.SessionState | ||||
| 			expectedSession *sessionsapi.SessionState | ||||
| 			store           sessionsapi.SessionStore | ||||
| 			refreshPeriod   time.Duration | ||||
| 			refreshSession  func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||
| 			validateSession func(context.Context, *sessionsapi.SessionState) bool | ||||
| 		} | ||||
| 
 | ||||
| 		DescribeTable("when serving a request", | ||||
| 			func(in storedSessionLoaderTableInput) { | ||||
| 				scope := &middlewareapi.RequestScope{ | ||||
| 					Session: in.existingSession, | ||||
| 				} | ||||
| 
 | ||||
| 				// Set up the request with the request headesr and a request scope
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				req.Header = in.requestHeaders | ||||
| 				contextWithScope := context.WithValue(req.Context(), requestScopeKey, scope) | ||||
| 				req = req.WithContext(contextWithScope) | ||||
| 
 | ||||
| 				rw := httptest.NewRecorder() | ||||
| 
 | ||||
| 				opts := &StoredSessionLoaderOptions{ | ||||
| 					SessionStore:           in.store, | ||||
| 					RefreshPeriod:          in.refreshPeriod, | ||||
| 					RefreshSessionIfNeeded: in.refreshSession, | ||||
| 					ValidateSessionState:   in.validateSession, | ||||
| 				} | ||||
| 
 | ||||
| 				// Create the handler with a next handler that will capture the session
 | ||||
| 				// from the scope
 | ||||
| 				var gotSession *sessionsapi.SessionState | ||||
| 				handler := NewStoredSessionLoader(opts)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 					gotSession = r.Context().Value(requestScopeKey).(*middlewareapi.RequestScope).Session | ||||
| 				})) | ||||
| 				handler.ServeHTTP(rw, req) | ||||
| 
 | ||||
| 				Expect(gotSession).To(Equal(in.expectedSession)) | ||||
| 			}, | ||||
| 			Entry("with no cookie", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders:  http.Header{}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with an invalid cookie", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=NonExistent"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with an existing session", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "Existing", | ||||
| 				}, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "Existing", | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that has not expired", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=NoRefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: noRefresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that cannot refresh and has expired", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=ExpiredNoRefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that can refresh, but is younger than refresh period", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   10 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("with a session that can refresh and is older than the refresh period", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "Refreshed", | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("when the provider refresh fails", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshError"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 			Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ | ||||
| 				requestHeaders: http.Header{ | ||||
| 					"Cookie": []string{"_oauth2_proxy=InvalidNoRefreshSession"}, | ||||
| 				}, | ||||
| 				existingSession: nil, | ||||
| 				expectedSession: nil, | ||||
| 				store:           defaultSessionStore, | ||||
| 				refreshPeriod:   1 * time.Minute, | ||||
| 				refreshSession:  defaultRefreshFunc, | ||||
| 				validateSession: defaultValidateFunc, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("refreshSessionIfNeeded", func() { | ||||
| 		type refreshSessionIfNeededTableInput struct { | ||||
| 			refreshPeriod   time.Duration | ||||
| 			session         *sessionsapi.SessionState | ||||
| 			expectedErr     error | ||||
| 			expectRefreshed bool | ||||
| 			expectValidated bool | ||||
| 		} | ||||
| 
 | ||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | ||||
| 		createdFuture := time.Now().Add(5 * time.Minute) | ||||
| 
 | ||||
| 		DescribeTable("with a session", | ||||
| 			func(in refreshSessionIfNeededTableInput) { | ||||
| 				refreshed := false | ||||
| 				validated := false | ||||
| 
 | ||||
| 				s := &storedSessionLoader{ | ||||
| 					refreshPeriod: in.refreshPeriod, | ||||
| 					store:         &fakeSessionStore{}, | ||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 						refreshed = true | ||||
| 						switch ss.RefreshToken { | ||||
| 						case refresh: | ||||
| 							return true, nil | ||||
| 						case noRefresh: | ||||
| 							return false, nil | ||||
| 						default: | ||||
| 							return false, errors.New("error refreshing session") | ||||
| 						} | ||||
| 					}, | ||||
| 					validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||
| 						validated = true | ||||
| 						return ss.AccessToken != "Invalid" | ||||
| 					}, | ||||
| 				} | ||||
| 
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				err := s.refreshSessionIfNeeded(nil, req, in.session) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||
| 				Expect(validated).To(Equal(in.expectValidated)) | ||||
| 			}, | ||||
| 			Entry("when the refresh period is 0, and the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: time.Duration(0), | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the refresh period is 0, and the session needs refreshing", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: time.Duration(0), | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session does not need refreshing", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session is refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: noRefresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: true, | ||||
| 			}), | ||||
| 			Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "RefreshError", | ||||
| 					CreatedAt:    &createdPast, | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: false, | ||||
| 			}), | ||||
| 			Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ | ||||
| 				refreshPeriod: 1 * time.Minute, | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					AccessToken:  "Invalid", | ||||
| 					RefreshToken: noRefresh, | ||||
| 					CreatedAt:    &createdPast, | ||||
| 					ExpiresOn:    &createdFuture, | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("session is invalid"), | ||||
| 				expectRefreshed: true, | ||||
| 				expectValidated: true, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("refreshSessionWithProvider", func() { | ||||
| 		type refreshSessionWithProviderTableInput struct { | ||||
| 			session         *sessionsapi.SessionState | ||||
| 			expectedErr     error | ||||
| 			expectRefreshed bool | ||||
| 			expectSaved     bool | ||||
| 		} | ||||
| 
 | ||||
| 		now := time.Now() | ||||
| 
 | ||||
| 		DescribeTable("when refreshing with the provider", | ||||
| 			func(in refreshSessionWithProviderTableInput) { | ||||
| 				saved := false | ||||
| 
 | ||||
| 				s := &storedSessionLoader{ | ||||
| 					store: &fakeSessionStore{ | ||||
| 						SaveFunc: func(_ http.ResponseWriter, _ *http.Request, ss *sessionsapi.SessionState) error { | ||||
| 							saved = true | ||||
| 							if ss.AccessToken == "NoSave" { | ||||
| 								return errors.New("unable to save session") | ||||
| 							} | ||||
| 							return nil | ||||
| 						}, | ||||
| 					}, | ||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||
| 						switch ss.RefreshToken { | ||||
| 						case refresh: | ||||
| 							return true, nil | ||||
| 						case noRefresh: | ||||
| 							return false, nil | ||||
| 						default: | ||||
| 							return false, errors.New("error refreshing session") | ||||
| 						} | ||||
| 					}, | ||||
| 				} | ||||
| 
 | ||||
| 				req := httptest.NewRequest("", "/", nil) | ||||
| 				refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) | ||||
| 				if in.expectedErr != nil { | ||||
| 					Expect(err).To(MatchError(in.expectedErr)) | ||||
| 				} else { | ||||
| 					Expect(err).ToNot(HaveOccurred()) | ||||
| 				} | ||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) | ||||
| 				Expect(saved).To(Equal(in.expectSaved)) | ||||
| 			}, | ||||
| 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: noRefresh, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: false, | ||||
| 				expectSaved:     false, | ||||
| 			}), | ||||
| 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 				}, | ||||
| 				expectedErr:     nil, | ||||
| 				expectRefreshed: true, | ||||
| 				expectSaved:     true, | ||||
| 			}), | ||||
| 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: "RefreshError", | ||||
| 					CreatedAt:    &now, | ||||
| 					ExpiresOn:    &now, | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | ||||
| 				expectRefreshed: false, | ||||
| 				expectSaved:     false, | ||||
| 			}), | ||||
| 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | ||||
| 				session: &sessionsapi.SessionState{ | ||||
| 					RefreshToken: refresh, | ||||
| 					AccessToken:  "NoSave", | ||||
| 				}, | ||||
| 				expectedErr:     errors.New("error saving session: unable to save session"), | ||||
| 				expectRefreshed: false, | ||||
| 				expectSaved:     true, | ||||
| 			}), | ||||
| 		) | ||||
| 	}) | ||||
| 
 | ||||
| 	Context("validateSession", func() { | ||||
| 		var s *storedSessionLoader | ||||
| 
 | ||||
| 		BeforeEach(func() { | ||||
| 			s = &storedSessionLoader{ | ||||
| 				validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||
| 					return ss.AccessToken == "Valid" | ||||
| 				}, | ||||
| 			} | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with a valid session", func() { | ||||
| 			It("does not return an error", func() { | ||||
| 				expires := time.Now().Add(1 * time.Minute) | ||||
| 				session := &sessionsapi.SessionState{ | ||||
| 					AccessToken: "Valid", | ||||
| 					ExpiresOn:   &expires, | ||||
| 				} | ||||
| 				Expect(s.validateSession(ctx, session)).To(Succeed()) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an expired session", func() { | ||||
| 			It("returns an error", func() { | ||||
| 				created := time.Now().Add(-5 * time.Minute) | ||||
| 				expires := time.Now().Add(-1 * time.Minute) | ||||
| 				session := &sessionsapi.SessionState{ | ||||
| 					AccessToken: "Valid", | ||||
| 					CreatedAt:   &created, | ||||
| 					ExpiresOn:   &expires, | ||||
| 				} | ||||
| 				Expect(s.validateSession(ctx, session)).To(MatchError("session is expired")) | ||||
| 			}) | ||||
| 		}) | ||||
| 
 | ||||
| 		Context("with an invalid session", func() { | ||||
| 			It("returns an error", func() { | ||||
| 				expires := time.Now().Add(1 * time.Minute) | ||||
| 				session := &sessionsapi.SessionState{ | ||||
| 					AccessToken: "Invalid", | ||||
| 					ExpiresOn:   &expires, | ||||
| 				} | ||||
| 				Expect(s.validateSession(ctx, session)).To(MatchError("session is invalid")) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| }) | ||||
| 
 | ||||
| type fakeSessionStore struct { | ||||
| 	SaveFunc  func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error | ||||
| 	LoadFunc  func(req *http.Request) (*sessionsapi.SessionState, error) | ||||
| 	ClearFunc func(rw http.ResponseWriter, req *http.Request) error | ||||
| } | ||||
| 
 | ||||
| func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { | ||||
| 	if f.SaveFunc != nil { | ||||
| 		return f.SaveFunc(rw, req, s) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| func (f *fakeSessionStore) Load(req *http.Request) (*sessionsapi.SessionState, error) { | ||||
| 	if f.LoadFunc != nil { | ||||
| 		return f.LoadFunc(req) | ||||
| 	} | ||||
| 	return nil, nil | ||||
| } | ||||
| 
 | ||||
| func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { | ||||
| 	if f.ClearFunc != nil { | ||||
| 		return f.ClearFunc(rw, req) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | @ -126,35 +126,8 @@ func (p *ProviderData) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S | |||
| 	return false, nil | ||||
| } | ||||
| 
 | ||||
| // CreateSessionStateFromBearerToken should be implemented to allow providers
 | ||||
| // to convert ID tokens into sessions
 | ||||
| func (p *ProviderData) CreateSessionStateFromBearerToken(ctx context.Context, rawIDToken string, idToken *oidc.IDToken) (*sessions.SessionState, error) { | ||||
| 	var claims struct { | ||||
| 		Subject           string `json:"sub"` | ||||
| 		Email             string `json:"email"` | ||||
| 		Verified          *bool  `json:"email_verified"` | ||||
| 		PreferredUsername string `json:"preferred_username"` | ||||
| 	} | ||||
| 
 | ||||
| 	if err := idToken.Claims(&claims); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to parse bearer token claims: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.Email == "" { | ||||
| 		claims.Email = claims.Subject | ||||
| 	} | ||||
| 
 | ||||
| 	if claims.Verified != nil && !*claims.Verified { | ||||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	newSession := &sessions.SessionState{ | ||||
| 		Email:             claims.Email, | ||||
| 		User:              claims.Subject, | ||||
| 		PreferredUsername: claims.PreferredUsername, | ||||
| 		AccessToken:       rawIDToken, | ||||
| 		IDToken:           rawIDToken, | ||||
| 		RefreshToken:      "", | ||||
| 		ExpiresOn:         &idToken.Expiry, | ||||
| 	} | ||||
| 
 | ||||
| 	return newSession, nil | ||||
| 	return nil, errors.New("not implemented") | ||||
| } | ||||
|  |  | |||
|  | @ -2,15 +2,10 @@ package providers | |||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/rsa" | ||||
| 	"net/url" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/coreos/go-oidc" | ||||
| 	"github.com/dgrijalva/jwt-go" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | @ -52,39 +47,3 @@ func TestAcrValuesConfigured(t *testing.T) { | |||
| 	result := p.GetLoginURL("https://my.test.app/oauth", "") | ||||
| 	assert.Contains(t, result, "acr_values=testValue") | ||||
| } | ||||
| 
 | ||||
| func TestCreateSessionStateFromBearerToken(t *testing.T) { | ||||
| 	minimalIDToken := jwt.StandardClaims{ | ||||
| 		Audience:  "asdf1234", | ||||
| 		ExpiresAt: time.Now().Add(time.Duration(5) * time.Minute).Unix(), | ||||
| 		Id:        "id-some-id", | ||||
| 		IssuedAt:  time.Now().Unix(), | ||||
| 		Issuer:    "https://issuer.example.com", | ||||
| 		NotBefore: 0, | ||||
| 		Subject:   "123456789", | ||||
| 	} | ||||
| 	// From oidc_test.go
 | ||||
| 	verifier := oidc.NewVerifier( | ||||
| 		"https://issuer.example.com", | ||||
| 		fakeKeySetStub{}, | ||||
| 		&oidc.Config{ClientID: "asdf1234"}, | ||||
| 	) | ||||
| 
 | ||||
| 	key, err := rsa.GenerateKey(rand.Reader, 2048) | ||||
| 	assert.NoError(t, err) | ||||
| 	rawIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS256, minimalIDToken).SignedString(key) | ||||
| 	assert.NoError(t, err) | ||||
| 	// Pass to a dummy Verifier to get an oidc.IDToken from the rawIDToken for our actual test below
 | ||||
| 	idToken, err := verifier.Verify(context.Background(), rawIDToken) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	session, err := (*ProviderData)(nil).CreateSessionStateFromBearerToken(context.Background(), rawIDToken, idToken) | ||||
| 	assert.NoError(t, err) | ||||
| 
 | ||||
| 	assert.Equal(t, rawIDToken, session.AccessToken) | ||||
| 	assert.Equal(t, rawIDToken, session.IDToken) | ||||
| 	assert.Equal(t, "123456789", session.Email) | ||||
| 	assert.Equal(t, "123456789", session.User) | ||||
| 	assert.Empty(t, session.RefreshToken) | ||||
| 	assert.Empty(t, session.PreferredUsername) | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue