Merge pull request #1086 from oauth2-proxy/early-refresh
Convert RefreshSessionIfNeeded into RefreshSession
This commit is contained in:
		
						commit
						16a9893a19
					
				|  | @ -4,10 +4,16 @@ | ||||||
| 
 | 
 | ||||||
| ## Important Notes | ## Important Notes | ||||||
| 
 | 
 | ||||||
|  | - [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) The extra validation to protect invalid session | ||||||
|  |   deserialization from v6.0.0 (only) has been removed to improve performance. If you are on v6.0.0, either upgrade | ||||||
|  |   to a version before this first and allow legacy sessions to expire gracefully or change your `cookie-secret` | ||||||
|  |   value and force all sessions to reauthenticate. | ||||||
|  | 
 | ||||||
| ## Breaking Changes | ## Breaking Changes | ||||||
| 
 | 
 | ||||||
| ## Changes since v7.1.3 | ## Changes since v7.1.3 | ||||||
| 
 | 
 | ||||||
|  | - [#1086](https://github.com/oauth2-proxy/oauth2-proxy/pull/1086) Refresh sessions before token expiration if configured (@NickMeves) | ||||||
| - [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed) | - [#1226](https://github.com/oauth2-proxy/oauth2-proxy/pull/1226) Move app redirection logic to its own package (@JoelSpeed) | ||||||
| - [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed) | - [#1128](https://github.com/oauth2-proxy/oauth2-proxy/pull/1128) Use gorilla mux for OAuth Proxy routing (@JoelSpeed) | ||||||
| - [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi) | - [#1238](https://github.com/oauth2-proxy/oauth2-proxy/pull/1238) Added ADFS provider (@samirachoadi) | ||||||
|  |  | ||||||
|  | @ -363,8 +363,8 @@ func buildSessionChain(opts *options.Options, sessionStore sessionsapi.SessionSt | ||||||
| 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | 	chain = chain.Append(middleware.NewStoredSessionLoader(&middleware.StoredSessionLoaderOptions{ | ||||||
| 		SessionStore:    sessionStore, | 		SessionStore:    sessionStore, | ||||||
| 		RefreshPeriod:   opts.Cookie.Refresh, | 		RefreshPeriod:   opts.Cookie.Refresh, | ||||||
| 		RefreshSessionIfNeeded: opts.GetProvider().RefreshSessionIfNeeded, | 		RefreshSession:  opts.GetProvider().RefreshSession, | ||||||
| 		ValidateSessionState:   opts.GetProvider().ValidateSession, | 		ValidateSession: opts.GetProvider().ValidateSession, | ||||||
| 	})) | 	})) | ||||||
| 
 | 
 | ||||||
| 	return chain | 	return chain | ||||||
|  | @ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	// Force setting these in case the Provider didn't
 | ||||||
|  | 	if s.CreatedAt == nil { | ||||||
|  | 		s.CreatedAtNow() | ||||||
|  | 	} | ||||||
|  | 	if s.ExpiresOn == nil { | ||||||
|  | 		s.ExpiresIn(p.CookieOptions.Expire) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	return s, nil | 	return s, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -3,14 +3,12 @@ package sessions | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"reflect" |  | ||||||
| 	"time" | 	"time" | ||||||
| 	"unicode/utf8" |  | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||||
| 	"github.com/pierrec/lz4" | 	"github.com/pierrec/lz4" | ||||||
| 	"github.com/vmihailenco/msgpack/v4" | 	"github.com/vmihailenco/msgpack/v4" | ||||||
|  | @ -32,6 +30,8 @@ type SessionState struct { | ||||||
| 	Groups            []string `msgpack:"g,omitempty"` | 	Groups            []string `msgpack:"g,omitempty"` | ||||||
| 	PreferredUsername string   `msgpack:"pu,omitempty"` | 	PreferredUsername string   `msgpack:"pu,omitempty"` | ||||||
| 
 | 
 | ||||||
|  | 	// Internal helpers, not serialized
 | ||||||
|  | 	Clock clock.Clock `msgpack:"-"` | ||||||
| 	Lock  Lock        `msgpack:"-"` | 	Lock  Lock        `msgpack:"-"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -63,9 +63,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { | ||||||
| 	return s.Lock.Peek(ctx) | 	return s.Lock.Peek(ctx) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CreatedAtNow sets a SessionState's CreatedAt to now
 | ||||||
|  | func (s *SessionState) CreatedAtNow() { | ||||||
|  | 	now := s.Clock.Now() | ||||||
|  | 	s.CreatedAt = &now | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // SetExpiresOn sets an expiration
 | ||||||
|  | func (s *SessionState) SetExpiresOn(exp time.Time) { | ||||||
|  | 	s.ExpiresOn = &exp | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ExpiresIn sets an expiration a certain duration from CreatedAt.
 | ||||||
|  | // CreatedAt will be set to time.Now if it is unset.
 | ||||||
|  | func (s *SessionState) ExpiresIn(d time.Duration) { | ||||||
|  | 	if s.CreatedAt == nil { | ||||||
|  | 		s.CreatedAtNow() | ||||||
|  | 	} | ||||||
|  | 	exp := s.CreatedAt.Add(d) | ||||||
|  | 	s.ExpiresOn = &exp | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // IsExpired checks whether the session has expired
 | // IsExpired checks whether the session has expired
 | ||||||
| func (s *SessionState) IsExpired() bool { | func (s *SessionState) IsExpired() bool { | ||||||
| 	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { | 	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| 	return false | 	return false | ||||||
|  | @ -74,7 +95,7 @@ func (s *SessionState) IsExpired() bool { | ||||||
| // Age returns the age of a session
 | // Age returns the age of a session
 | ||||||
| func (s *SessionState) Age() time.Duration { | func (s *SessionState) Age() time.Duration { | ||||||
| 	if s.CreatedAt != nil && !s.CreatedAt.IsZero() { | 	if s.CreatedAt != nil && !s.CreatedAt.IsZero() { | ||||||
| 		return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) | 		return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) | ||||||
| 	} | 	} | ||||||
| 	return 0 | 	return 0 | ||||||
| } | } | ||||||
|  | @ -177,11 +198,6 @@ func DecodeSessionState(data []byte, c encryption.Cipher, compressed bool) (*Ses | ||||||
| 		return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) | 		return nil, fmt.Errorf("error unmarshalling data to session state: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = ss.validate() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return nil, err |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return &ss, nil | 	return &ss, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -235,35 +251,3 @@ func lz4Decompress(compressed []byte) ([]byte, error) { | ||||||
| 
 | 
 | ||||||
| 	return payload, nil | 	return payload, nil | ||||||
| } | } | ||||||
| 
 |  | ||||||
| // validate ensures the decoded session is non-empty and contains valid data
 |  | ||||||
| //
 |  | ||||||
| // Non-empty check is needed due to ensure the non-authenticated AES-CFB
 |  | ||||||
| // decryption doesn't result in garbage data that collides with a valid
 |  | ||||||
| // MessagePack header bytes (which MessagePack will unmarshal to an empty
 |  | ||||||
| // default SessionState). <1% chance, but observed with random test data.
 |  | ||||||
| //
 |  | ||||||
| // UTF-8 check ensures the strings are valid and not raw bytes overloaded
 |  | ||||||
| // into Latin-1 encoding. The occurs when legacy unencrypted fields are
 |  | ||||||
| // decrypted with AES-CFB which results in random bytes.
 |  | ||||||
| func (s *SessionState) validate() error { |  | ||||||
| 	for _, field := range []string{ |  | ||||||
| 		s.User, |  | ||||||
| 		s.Email, |  | ||||||
| 		s.PreferredUsername, |  | ||||||
| 		s.AccessToken, |  | ||||||
| 		s.IDToken, |  | ||||||
| 		s.RefreshToken, |  | ||||||
| 	} { |  | ||||||
| 		if !utf8.ValidString(field) { |  | ||||||
| 			return errors.New("invalid non-UTF8 field in session") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	empty := new(SessionState) |  | ||||||
| 	if reflect.DeepEqual(*s, *empty) { |  | ||||||
| 		return errors.New("invalid empty session unmarshalled") |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -16,6 +16,30 @@ func timePtr(t time.Time) *time.Time { | ||||||
| 	return &t | 	return &t | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestCreatedAtNow(t *testing.T) { | ||||||
|  | 	g := NewWithT(t) | ||||||
|  | 	ss := &SessionState{} | ||||||
|  | 
 | ||||||
|  | 	now := time.Unix(1234567890, 0) | ||||||
|  | 	ss.Clock.Set(now) | ||||||
|  | 
 | ||||||
|  | 	ss.CreatedAtNow() | ||||||
|  | 	g.Expect(*ss.CreatedAt).To(Equal(now)) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestExpiresIn(t *testing.T) { | ||||||
|  | 	g := NewWithT(t) | ||||||
|  | 	ss := &SessionState{} | ||||||
|  | 
 | ||||||
|  | 	now := time.Unix(1234567890, 0) | ||||||
|  | 	ss.Clock.Set(now) | ||||||
|  | 
 | ||||||
|  | 	ttl := time.Duration(743) * time.Second | ||||||
|  | 	ss.ExpiresIn(ttl) | ||||||
|  | 
 | ||||||
|  | 	g.Expect(*ss.ExpiresOn).To(Equal(ss.CreatedAt.Add(ttl))) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestString(t *testing.T) { | func TestString(t *testing.T) { | ||||||
| 	g := NewWithT(t) | 	g := NewWithT(t) | ||||||
| 	created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z") | 	created, err := time.Parse(time.RFC3339, "2000-01-01T00:00:00Z") | ||||||
|  |  | ||||||
|  | @ -63,13 +63,10 @@ func Reset() *clockapi.Mock { | ||||||
| // package.
 | // package.
 | ||||||
| type Clock struct { | type Clock struct { | ||||||
| 	mock *clockapi.Mock | 	mock *clockapi.Mock | ||||||
| 	sync.Mutex |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Set sets the Clock to a clock.Mock at the given time.Time
 | // Set sets the Clock to a clock.Mock at the given time.Time
 | ||||||
| func (c *Clock) Set(t time.Time) { | func (c *Clock) Set(t time.Time) { | ||||||
| 	c.Lock() |  | ||||||
| 	defer c.Unlock() |  | ||||||
| 	if c.mock == nil { | 	if c.mock == nil { | ||||||
| 		c.mock = clockapi.NewMock() | 		c.mock = clockapi.NewMock() | ||||||
| 	} | 	} | ||||||
|  | @ -79,8 +76,6 @@ func (c *Clock) Set(t time.Time) { | ||||||
| // Add moves clock forward time.Duration if it is mocked. It will error
 | // Add moves clock forward time.Duration if it is mocked. It will error
 | ||||||
| // if the clock is not mocked.
 | // if the clock is not mocked.
 | ||||||
| func (c *Clock) Add(d time.Duration) error { | func (c *Clock) Add(d time.Duration) error { | ||||||
| 	c.Lock() |  | ||||||
| 	defer c.Unlock() |  | ||||||
| 	if c.mock == nil { | 	if c.mock == nil { | ||||||
| 		return errors.New("clock not mocked") | 		return errors.New("clock not mocked") | ||||||
| 	} | 	} | ||||||
|  | @ -91,8 +86,6 @@ func (c *Clock) Add(d time.Duration) error { | ||||||
| // Reset removes local clock.Mock.  Returns any existing Mock if set in case
 | // Reset removes local clock.Mock.  Returns any existing Mock if set in case
 | ||||||
| // lingering time operations are attached to it.
 | // lingering time operations are attached to it.
 | ||||||
| func (c *Clock) Reset() *clockapi.Mock { | func (c *Clock) Reset() *clockapi.Mock { | ||||||
| 	c.Lock() |  | ||||||
| 	defer c.Unlock() |  | ||||||
| 	existing := c.mock | 	existing := c.mock | ||||||
| 	c.mock = nil | 	c.mock = nil | ||||||
| 	return existing | 	return existing | ||||||
|  |  | ||||||
|  | @ -11,25 +11,26 @@ import ( | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| // StoredSessionLoaderOptions cotnains all of the requirements to construct
 | // StoredSessionLoaderOptions contains all of the requirements to construct
 | ||||||
| // a stored session loader.
 | // a stored session loader.
 | ||||||
| // All options must be provided.
 | // All options must be provided.
 | ||||||
| type StoredSessionLoaderOptions struct { | type StoredSessionLoaderOptions struct { | ||||||
| 	// Session storage basckend
 | 	// Session storage backend
 | ||||||
| 	SessionStore sessionsapi.SessionStore | 	SessionStore sessionsapi.SessionStore | ||||||
| 
 | 
 | ||||||
| 	// How often should sessions be refreshed
 | 	// How often should sessions be refreshed
 | ||||||
| 	RefreshPeriod time.Duration | 	RefreshPeriod time.Duration | ||||||
| 
 | 
 | ||||||
| 	// Provider based sesssion refreshing
 | 	// Provider based session refreshing
 | ||||||
| 	RefreshSessionIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | 	RefreshSession func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
| 
 | 
 | ||||||
| 	// Provider based session validation.
 | 	// Provider based session validation.
 | ||||||
| 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | 	// If the sesssion is older than `RefreshPeriod` but the provider doesn't
 | ||||||
| 	// refresh it, we must re-validate using this validation.
 | 	// refresh it, we must re-validate using this validation.
 | ||||||
| 	ValidateSessionState func(context.Context, *sessionsapi.SessionState) bool | 	ValidateSession func(context.Context, *sessionsapi.SessionState) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | // NewStoredSessionLoader creates a new storedSessionLoader which loads
 | ||||||
|  | @ -40,8 +41,8 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor | ||||||
| 	ss := &storedSessionLoader{ | 	ss := &storedSessionLoader{ | ||||||
| 		store:            opts.SessionStore, | 		store:            opts.SessionStore, | ||||||
| 		refreshPeriod:    opts.RefreshPeriod, | 		refreshPeriod:    opts.RefreshPeriod, | ||||||
| 		refreshSessionWithProviderIfNeeded: opts.RefreshSessionIfNeeded, | 		sessionRefresher: opts.RefreshSession, | ||||||
| 		validateSessionState:               opts.ValidateSessionState, | 		sessionValidator: opts.ValidateSession, | ||||||
| 	} | 	} | ||||||
| 	return ss.loadSession | 	return ss.loadSession | ||||||
| } | } | ||||||
|  | @ -51,8 +52,8 @@ func NewStoredSessionLoader(opts *StoredSessionLoaderOptions) alice.Constructor | ||||||
| type storedSessionLoader struct { | type storedSessionLoader struct { | ||||||
| 	store            sessionsapi.SessionStore | 	store            sessionsapi.SessionStore | ||||||
| 	refreshPeriod    time.Duration | 	refreshPeriod    time.Duration | ||||||
| 	refreshSessionWithProviderIfNeeded func(context.Context, *sessionsapi.SessionState) (bool, error) | 	sessionRefresher func(context.Context, *sessionsapi.SessionState) (bool, error) | ||||||
| 	validateSessionState               func(context.Context, *sessionsapi.SessionState) bool | 	sessionValidator func(context.Context, *sessionsapi.SessionState) bool | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // loadSession attempts to load a session as identified by the request cookies.
 | // loadSession attempts to load a session as identified by the request cookies.
 | ||||||
|  | @ -108,49 +109,59 @@ func (s *storedSessionLoader) getValidatedSession(rw http.ResponseWriter, req *h | ||||||
| 
 | 
 | ||||||
| // refreshSessionIfNeeded will attempt to refresh a session if the session
 | // refreshSessionIfNeeded will attempt to refresh a session if the session
 | ||||||
| // is older than the refresh period.
 | // is older than the refresh period.
 | ||||||
| // It is assumed that if the provider refreshes the session, the session is now
 | // Success or fail, we will then validate the session.
 | ||||||
| // valid.
 |  | ||||||
| // If the session requires refreshing but the provider does not refresh it,
 |  | ||||||
| // we must validate the session to ensure that the returned session is still
 |  | ||||||
| // valid.
 |  | ||||||
| func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | func (s *storedSessionLoader) refreshSessionIfNeeded(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
| 	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { | 	if s.refreshPeriod <= time.Duration(0) || session.Age() < s.refreshPeriod { | ||||||
| 		// Refresh is disabled or the session is not old enough, do nothing
 | 		// Refresh is disabled or the session is not old enough, do nothing
 | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, s.refreshPeriod) | 	logger.Printf("Refreshing session - User: %s; SessionAge: %s", session.User, session.Age()) | ||||||
| 	refreshed, err := s.refreshSessionWithProvider(rw, req, session) | 	err := s.refreshSession(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		// If a preemptive refresh fails, we still keep the session
 | ||||||
|  | 		// if validateSession succeeds.
 | ||||||
|  | 		logger.Errorf("Unable to refresh session: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !refreshed { | 	// Validate all sessions after any Redeem/Refresh operation (fail or success)
 | ||||||
| 		// Session wasn't refreshed, so make sure it's still valid
 |  | ||||||
| 	return s.validateSession(req.Context(), session) | 	return s.validateSession(req.Context(), session) | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // refreshSession attempts to refresh the session with the provider
 | ||||||
|  | // and will save the session if it was updated.
 | ||||||
|  | func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) error { | ||||||
|  | 	refreshed, err := s.sessionRefresher(req.Context(), session) | ||||||
|  | 	if err != nil && !errors.Is(err, providers.ErrNotImplemented) { | ||||||
|  | 		return fmt.Errorf("error refreshing tokens: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// HACK:
 | ||||||
|  | 	// Providers that don't implement `RefreshSession` use the default
 | ||||||
|  | 	// implementation which returns `ErrNotImplemented`.
 | ||||||
|  | 	// Pretend it refreshed to reset the refresh timer so that `ValidateSession`
 | ||||||
|  | 	// isn't triggered every subsequent request and is only called once during
 | ||||||
|  | 	// this request.
 | ||||||
|  | 	if errors.Is(err, providers.ErrNotImplemented) { | ||||||
|  | 		refreshed = true | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// Session not refreshed, nothing to persist.
 | ||||||
|  | 	if !refreshed { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| // refreshSessionWithProvider attempts to refresh the sessinon with the provider
 | 	// If we refreshed, update the `CreatedAt` time to reset the refresh timer
 | ||||||
| // and will save the session if it was updated.
 | 	// (In case underlying provider implementations forget)
 | ||||||
| func (s *storedSessionLoader) refreshSessionWithProvider(rw http.ResponseWriter, req *http.Request, session *sessionsapi.SessionState) (bool, error) { | 	session.CreatedAtNow() | ||||||
| 	refreshed, err := s.refreshSessionWithProviderIfNeeded(req.Context(), session) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return false, fmt.Errorf("error refreshing access token: %v", err) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	if !refreshed { |  | ||||||
| 		return false, nil |  | ||||||
| 	} |  | ||||||
| 
 | 
 | ||||||
| 	// Because the session was refreshed, make sure to save it
 | 	// Because the session was refreshed, make sure to save it
 | ||||||
| 	err = s.store.Save(rw, req, session) | 	err = s.store.Save(rw, req, session) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | 		logger.PrintAuthf(session.Email, req, logger.AuthError, "error saving session: %v", err) | ||||||
| 		return false, fmt.Errorf("error saving session: %v", err) | 		return fmt.Errorf("error saving session: %v", err) | ||||||
| 	} | 	} | ||||||
| 	return true, nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // validateSession checks whether the session has expired and performs
 | // validateSession checks whether the session has expired and performs
 | ||||||
|  | @ -161,7 +172,7 @@ func (s *storedSessionLoader) validateSession(ctx context.Context, session *sess | ||||||
| 		return errors.New("session is expired") | 		return errors.New("session is expired") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if !s.validateSessionState(ctx, session) { | 	if !s.sessionValidator(ctx, session) { | ||||||
| 		return errors.New("session is invalid") | 		return errors.New("session is invalid") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -10,6 +10,8 @@ import ( | ||||||
| 
 | 
 | ||||||
| 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||||
|  | 	"github.com/oauth2-proxy/oauth2-proxy/v7/providers" | ||||||
| 	. "github.com/onsi/ginkgo" | 	. "github.com/onsi/ginkgo" | ||||||
| 	. "github.com/onsi/ginkgo/extensions/table" | 	. "github.com/onsi/ginkgo/extensions/table" | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
|  | @ -19,13 +21,15 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 	const ( | 	const ( | ||||||
| 		refresh        = "Refresh" | 		refresh        = "Refresh" | ||||||
| 		noRefresh      = "NoRefresh" | 		noRefresh      = "NoRefresh" | ||||||
|  | 		notImplemented = "NotImplemented" | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	var ctx = context.Background() | 	var ctx = context.Background() | ||||||
| 
 | 
 | ||||||
| 	Context("StoredSessionLoader", func() { | 	Context("StoredSessionLoader", func() { | ||||||
| 		createdPast := time.Now().Add(-5 * time.Minute) | 		now := time.Now() | ||||||
| 		createdFuture := time.Now().Add(5 * time.Minute) | 		createdPast := now.Add(-5 * time.Minute) | ||||||
|  | 		createdFuture := now.Add(5 * time.Minute) | ||||||
| 
 | 
 | ||||||
| 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 		var defaultRefreshFunc = func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 			switch ss.RefreshToken { | 			switch ss.RefreshToken { | ||||||
|  | @ -85,6 +89,14 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 			}, | 			}, | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | 		BeforeEach(func() { | ||||||
|  | 			clock.Set(now) | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
|  | 		AfterEach(func() { | ||||||
|  | 			clock.Reset() | ||||||
|  | 		}) | ||||||
|  | 
 | ||||||
| 		type storedSessionLoaderTableInput struct { | 		type storedSessionLoaderTableInput struct { | ||||||
| 			requestHeaders  http.Header | 			requestHeaders  http.Header | ||||||
| 			existingSession *sessionsapi.SessionState | 			existingSession *sessionsapi.SessionState | ||||||
|  | @ -111,8 +123,8 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				opts := &StoredSessionLoaderOptions{ | 				opts := &StoredSessionLoaderOptions{ | ||||||
| 					SessionStore:    in.store, | 					SessionStore:    in.store, | ||||||
| 					RefreshPeriod:   in.refreshPeriod, | 					RefreshPeriod:   in.refreshPeriod, | ||||||
| 					RefreshSessionIfNeeded: in.refreshSession, | 					RefreshSession:  in.refreshSession, | ||||||
| 					ValidateSessionState:   in.validateSession, | 					ValidateSession: in.validateSession, | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 				// Create the handler with a next handler that will capture the session
 | 				// Create the handler with a next handler that will capture the session
 | ||||||
|  | @ -208,6 +220,21 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				existingSession: nil, | 				existingSession: nil, | ||||||
| 				expectedSession: &sessionsapi.SessionState{ | 				expectedSession: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: "Refreshed", | 					RefreshToken: "Refreshed", | ||||||
|  | 					CreatedAt:    &now, | ||||||
|  | 					ExpiresOn:    &createdFuture, | ||||||
|  | 				}, | ||||||
|  | 				store:           defaultSessionStore, | ||||||
|  | 				refreshPeriod:   1 * time.Minute, | ||||||
|  | 				refreshSession:  defaultRefreshFunc, | ||||||
|  | 				validateSession: defaultValidateFunc, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider refresh fails but validation succeeds", storedSessionLoaderTableInput{ | ||||||
|  | 				requestHeaders: http.Header{ | ||||||
|  | 					"Cookie": []string{"_oauth2_proxy=RefreshError"}, | ||||||
|  | 				}, | ||||||
|  | 				existingSession: nil, | ||||||
|  | 				expectedSession: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: "RefreshError", | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 					ExpiresOn:    &createdFuture, | 					ExpiresOn:    &createdFuture, | ||||||
| 				}, | 				}, | ||||||
|  | @ -216,7 +243,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				refreshSession:  defaultRefreshFunc, | 				refreshSession:  defaultRefreshFunc, | ||||||
| 				validateSession: defaultValidateFunc, | 				validateSession: defaultValidateFunc, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider refresh fails", storedSessionLoaderTableInput{ | 			Entry("when the provider refresh fails and validation fails", storedSessionLoaderTableInput{ | ||||||
| 				requestHeaders: http.Header{ | 				requestHeaders: http.Header{ | ||||||
| 					"Cookie": []string{"_oauth2_proxy=RefreshError"}, | 					"Cookie": []string{"_oauth2_proxy=RefreshError"}, | ||||||
| 				}, | 				}, | ||||||
|  | @ -225,7 +252,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				store:           defaultSessionStore, | 				store:           defaultSessionStore, | ||||||
| 				refreshPeriod:   1 * time.Minute, | 				refreshPeriod:   1 * time.Minute, | ||||||
| 				refreshSession:  defaultRefreshFunc, | 				refreshSession:  defaultRefreshFunc, | ||||||
| 				validateSession: defaultValidateFunc, | 				validateSession: func(context.Context, *sessionsapi.SessionState) bool { return false }, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ | 			Entry("when the session is not refreshed and is no longer valid", storedSessionLoaderTableInput{ | ||||||
| 				requestHeaders: http.Header{ | 				requestHeaders: http.Header{ | ||||||
|  | @ -261,18 +288,20 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				s := &storedSessionLoader{ | 				s := &storedSessionLoader{ | ||||||
| 					refreshPeriod: in.refreshPeriod, | 					refreshPeriod: in.refreshPeriod, | ||||||
| 					store:         &fakeSessionStore{}, | 					store:         &fakeSessionStore{}, | ||||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 					sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 						refreshed = true | 						refreshed = true | ||||||
| 						switch ss.RefreshToken { | 						switch ss.RefreshToken { | ||||||
| 						case refresh: | 						case refresh: | ||||||
| 							return true, nil | 							return true, nil | ||||||
| 						case noRefresh: | 						case noRefresh: | ||||||
| 							return false, nil | 							return false, nil | ||||||
|  | 						case notImplemented: | ||||||
|  | 							return false, providers.ErrNotImplemented | ||||||
| 						default: | 						default: | ||||||
| 							return false, errors.New("error refreshing session") | 							return false, errors.New("error refreshing session") | ||||||
| 						} | 						} | ||||||
| 					}, | 					}, | ||||||
| 					validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | 					sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
| 						validated = true | 						validated = true | ||||||
| 						return ss.AccessToken != "Invalid" | 						return ss.AccessToken != "Invalid" | ||||||
| 					}, | 					}, | ||||||
|  | @ -326,7 +355,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     nil, | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: false, | 				expectValidated: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | 			Entry("when the session is not refreshed by the provider", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | @ -339,15 +368,25 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: true, | 				expectValidated: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider refresh fails", refreshSessionIfNeededTableInput{ | 			Entry("when the provider doesn't implement refresh but validation succeeds", refreshSessionIfNeededTableInput{ | ||||||
|  | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: notImplemented, | ||||||
|  | 					CreatedAt:    &createdPast, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr:     nil, | ||||||
|  | 				expectRefreshed: true, | ||||||
|  | 				expectValidated: true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider refresh fails but validation succeeds", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
| 				session: &sessionsapi.SessionState{ | 				session: &sessionsapi.SessionState{ | ||||||
| 					RefreshToken: "RefreshError", | 					RefreshToken: "RefreshError", | ||||||
| 					CreatedAt:    &createdPast, | 					CreatedAt:    &createdPast, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | 				expectedErr:     nil, | ||||||
| 				expectRefreshed: true, | 				expectRefreshed: true, | ||||||
| 				expectValidated: false, | 				expectValidated: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ | 			Entry("when the session is not refreshed by the provider and validation fails", refreshSessionIfNeededTableInput{ | ||||||
| 				refreshPeriod: 1 * time.Minute, | 				refreshPeriod: 1 * time.Minute, | ||||||
|  | @ -364,11 +403,10 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 		) | 		) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	Context("refreshSessionWithProvider", func() { | 	Context("refreshSession", func() { | ||||||
| 		type refreshSessionWithProviderTableInput struct { | 		type refreshSessionWithProviderTableInput struct { | ||||||
| 			session     *sessionsapi.SessionState | 			session     *sessionsapi.SessionState | ||||||
| 			expectedErr error | 			expectedErr error | ||||||
| 			expectRefreshed bool |  | ||||||
| 			expectSaved bool | 			expectSaved bool | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
|  | @ -388,12 +426,14 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 							return nil | 							return nil | ||||||
| 						}, | 						}, | ||||||
| 					}, | 					}, | ||||||
| 					refreshSessionWithProviderIfNeeded: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | 					sessionRefresher: func(_ context.Context, ss *sessionsapi.SessionState) (bool, error) { | ||||||
| 						switch ss.RefreshToken { | 						switch ss.RefreshToken { | ||||||
| 						case refresh: | 						case refresh: | ||||||
| 							return true, nil | 							return true, nil | ||||||
| 						case noRefresh: | 						case noRefresh: | ||||||
| 							return false, nil | 							return false, nil | ||||||
|  | 						case notImplemented: | ||||||
|  | 							return false, providers.ErrNotImplemented | ||||||
| 						default: | 						default: | ||||||
| 							return false, errors.New("error refreshing session") | 							return false, errors.New("error refreshing session") | ||||||
| 						} | 						} | ||||||
|  | @ -402,13 +442,12 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 
 | 
 | ||||||
| 				req := httptest.NewRequest("", "/", nil) | 				req := httptest.NewRequest("", "/", nil) | ||||||
| 				req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | 				req = middlewareapi.AddRequestScope(req, &middlewareapi.RequestScope{}) | ||||||
| 				refreshed, err := s.refreshSessionWithProvider(nil, req, in.session) | 				err := s.refreshSession(nil, req, in.session) | ||||||
| 				if in.expectedErr != nil { | 				if in.expectedErr != nil { | ||||||
| 					Expect(err).To(MatchError(in.expectedErr)) | 					Expect(err).To(MatchError(in.expectedErr)) | ||||||
| 				} else { | 				} else { | ||||||
| 					Expect(err).ToNot(HaveOccurred()) | 					Expect(err).ToNot(HaveOccurred()) | ||||||
| 				} | 				} | ||||||
| 				Expect(refreshed).To(Equal(in.expectRefreshed)) |  | ||||||
| 				Expect(saved).To(Equal(in.expectSaved)) | 				Expect(saved).To(Equal(in.expectSaved)) | ||||||
| 			}, | 			}, | ||||||
| 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | 			Entry("when the provider does not refresh the session", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -416,7 +455,6 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: noRefresh, | 					RefreshToken: noRefresh, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr: nil, | 				expectedErr: nil, | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectSaved: false, | 				expectSaved: false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | 			Entry("when the provider refreshes the session", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -424,7 +462,13 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					RefreshToken: refresh, | 					RefreshToken: refresh, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr: nil, | 				expectedErr: nil, | ||||||
| 				expectRefreshed: true, | 				expectSaved: true, | ||||||
|  | 			}), | ||||||
|  | 			Entry("when the provider doesn't implement refresh", refreshSessionWithProviderTableInput{ | ||||||
|  | 				session: &sessionsapi.SessionState{ | ||||||
|  | 					RefreshToken: notImplemented, | ||||||
|  | 				}, | ||||||
|  | 				expectedErr: nil, | ||||||
| 				expectSaved: true, | 				expectSaved: true, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | 			Entry("when the provider returns an error", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -433,8 +477,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					CreatedAt:    &now, | 					CreatedAt:    &now, | ||||||
| 					ExpiresOn:    &now, | 					ExpiresOn:    &now, | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr:     errors.New("error refreshing access token: error refreshing session"), | 				expectedErr: errors.New("error refreshing tokens: error refreshing session"), | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectSaved: false, | 				expectSaved: false, | ||||||
| 			}), | 			}), | ||||||
| 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | 			Entry("when the saving the session returns an error", refreshSessionWithProviderTableInput{ | ||||||
|  | @ -443,7 +486,6 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 					AccessToken:  "NoSave", | 					AccessToken:  "NoSave", | ||||||
| 				}, | 				}, | ||||||
| 				expectedErr: errors.New("error saving session: unable to save session"), | 				expectedErr: errors.New("error saving session: unable to save session"), | ||||||
| 				expectRefreshed: false, |  | ||||||
| 				expectSaved: true, | 				expectSaved: true, | ||||||
| 			}), | 			}), | ||||||
| 		) | 		) | ||||||
|  | @ -454,7 +496,7 @@ var _ = Describe("Stored Session Suite", func() { | ||||||
| 
 | 
 | ||||||
| 		BeforeEach(func() { | 		BeforeEach(func() { | ||||||
| 			s = &storedSessionLoader{ | 			s = &storedSessionLoader{ | ||||||
| 				validateSessionState: func(_ context.Context, ss *sessionsapi.SessionState) bool { | 				sessionValidator: func(_ context.Context, ss *sessionsapi.SessionState) bool { | ||||||
| 					return ss.AccessToken == "Valid" | 					return ss.AccessToken == "Valid" | ||||||
| 				}, | 				}, | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | @ -36,8 +36,7 @@ type SessionStore struct { | ||||||
| // within Cookies set on the HTTP response writer
 | // within Cookies set on the HTTP response writer
 | ||||||
| func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||||
| 	if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { | 	if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { | ||||||
| 		now := time.Now() | 		ss.CreatedAtNow() | ||||||
| 		ss.CreatedAt = &now |  | ||||||
| 	} | 	} | ||||||
| 	value, err := s.cookieForSession(ss) | 	value, err := s.cookieForSession(ss) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  |  | ||||||
|  | @ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager { | ||||||
| // from the persistent data store.
 | // from the persistent data store.
 | ||||||
| func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||||
| 	if s.CreatedAt == nil || s.CreatedAt.IsZero() { | 	if s.CreatedAt == nil || s.CreatedAt.IsZero() { | ||||||
| 		now := time.Now() | 		s.CreatedAtNow() | ||||||
| 		s.CreatedAt = &now |  | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	tckt, err := decodeTicketFromRequest(req, m.Options) | 	tckt, err := decodeTicketFromRequest(req, m.Options) | ||||||
|  |  | ||||||
|  | @ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() |  | ||||||
| 	expires := time.Unix(jsonResponse.ExpiresOn, 0) |  | ||||||
| 
 |  | ||||||
| 	session := &sessions.SessionState{ | 	session := &sessions.SessionState{ | ||||||
| 		AccessToken:  jsonResponse.AccessToken, | 		AccessToken:  jsonResponse.AccessToken, | ||||||
| 		IDToken:      jsonResponse.IDToken, | 		IDToken:      jsonResponse.IDToken, | ||||||
| 		CreatedAt:    &created, |  | ||||||
| 		ExpiresOn:    &expires, |  | ||||||
| 		RefreshToken: jsonResponse.RefreshToken, | 		RefreshToken: jsonResponse.RefreshToken, | ||||||
| 	} | 	} | ||||||
|  | 	session.CreatedAtNow() | ||||||
|  | 	session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | ||||||
| 
 | 
 | ||||||
| 	email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) | 	email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) | ||||||
| 
 | 
 | ||||||
|  | @ -239,28 +236,29 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st | ||||||
| 	return email, nil | 	return email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	origExpiration := s.ExpiresOn |  | ||||||
| 
 |  | ||||||
| 	err := p.redeemRefreshToken(ctx, s) | 	err := p.redeemRefreshToken(ctx, s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) |  | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||||
|  | 	clientSecret, err := p.GetClientSecret() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	params := url.Values{} | 	params := url.Values{} | ||||||
| 	params.Add("client_id", p.ClientID) | 	params.Add("client_id", p.ClientID) | ||||||
| 	params.Add("client_secret", p.ClientSecret) | 	params.Add("client_secret", clientSecret) | ||||||
| 	params.Add("refresh_token", s.RefreshToken) | 	params.Add("refresh_token", s.RefreshToken) | ||||||
| 	params.Add("grant_type", "refresh_token") | 	params.Add("grant_type", "refresh_token") | ||||||
| 
 | 
 | ||||||
|  | @ -278,18 +276,16 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess | ||||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&jsonResponse) | 		UnmarshalInto(&jsonResponse) | ||||||
| 
 |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	now := time.Now() |  | ||||||
| 	expires := time.Unix(jsonResponse.ExpiresOn, 0) |  | ||||||
| 	s.AccessToken = jsonResponse.AccessToken | 	s.AccessToken = jsonResponse.AccessToken | ||||||
| 	s.IDToken = jsonResponse.IDToken | 	s.IDToken = jsonResponse.IDToken | ||||||
| 	s.RefreshToken = jsonResponse.RefreshToken | 	s.RefreshToken = jsonResponse.RefreshToken | ||||||
| 	s.CreatedAt = &now | 
 | ||||||
| 	s.ExpiresOn = &expires | 	s.CreatedAtNow() | ||||||
|  | 	s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | ||||||
| 
 | 
 | ||||||
| 	email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) | 	email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) | ||||||
| 
 | 
 | ||||||
|  | @ -312,7 +308,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func makeAzureHeader(accessToken string) http.Header { | func makeAzureHeader(accessToken string) http.Header { | ||||||
|  |  | ||||||
|  | @ -340,17 +340,7 @@ func TestAzureProviderProtectedResourceConfigured(t *testing.T) { | ||||||
| 	assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) | 	assert.Contains(t, result, "resource="+url.QueryEscape("http://my.resource.test")) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAzureProviderNotRefreshWhenNotExpired(t *testing.T) { | func TestAzureProviderRefresh(t *testing.T) { | ||||||
| 	p := testAzureProvider("") |  | ||||||
| 
 |  | ||||||
| 	expires := time.Now().Add(time.Duration(1) * time.Hour) |  | ||||||
| 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} |  | ||||||
| 	refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) |  | ||||||
| 	assert.Equal(t, nil, err) |  | ||||||
| 	assert.False(t, refreshNeeded) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| func TestAzureProviderRefreshWhenExpired(t *testing.T) { |  | ||||||
| 	email := "foo@example.com" | 	email := "foo@example.com" | ||||||
| 	idToken := idTokenClaims{Email: email} | 	idToken := idTokenClaims{Email: email} | ||||||
| 	idTokenString, err := newSignedTestIDToken(idToken) | 	idTokenString, err := newSignedTestIDToken(idToken) | ||||||
|  | @ -373,9 +363,10 @@ func TestAzureProviderRefreshWhenExpired(t *testing.T) { | ||||||
| 
 | 
 | ||||||
| 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | 	expires := time.Now().Add(time.Duration(-1) * time.Hour) | ||||||
| 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | 	session := &sessions.SessionState{AccessToken: "some_access_token", RefreshToken: "some_refresh_token", IDToken: "some_id_token", ExpiresOn: &expires} | ||||||
| 	refreshNeeded, err := p.RefreshSessionIfNeeded(context.Background(), session) | 
 | ||||||
|  | 	refreshed, err := p.RefreshSession(context.Background(), session) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.True(t, refreshNeeded) | 	assert.True(t, refreshed) | ||||||
| 	assert.NotEqual(t, session, nil) | 	assert.NotEqual(t, session, nil) | ||||||
| 	assert.Equal(t, "new_some_access_token", session.AccessToken) | 	assert.Equal(t, "new_some_access_token", session.AccessToken) | ||||||
| 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | 	assert.Equal(t, "new_some_refresh_token", session.RefreshToken) | ||||||
|  |  | ||||||
|  | @ -88,7 +88,7 @@ func (p *FacebookProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	return r.Email, nil | 	return r.Email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSession validates the AccessToken
 | ||||||
| func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | func (p *FacebookProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeOIDCHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -121,10 +121,9 @@ func (p *GitLabProvider) SetProjectScope() { | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | func (p *GitLabProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -139,10 +138,10 @@ func (p *GitLabProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	c := oauth2.Config{ | 	c := oauth2.Config{ | ||||||
|  | @ -164,13 +163,9 @@ func (p *GitLabProvider) redeemRefreshToken(ctx context.Context, s *sessions.Ses | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("unable to update session: %v", err) | 		return fmt.Errorf("unable to update session: %v", err) | ||||||
| 	} | 	} | ||||||
| 	s.AccessToken = newSession.AccessToken | 	*s = *newSession | ||||||
| 	s.IDToken = newSession.IDToken | 
 | ||||||
| 	s.RefreshToken = newSession.RefreshToken | 	return nil | ||||||
| 	s.CreatedAt = newSession.CreatedAt |  | ||||||
| 	s.ExpiresOn = newSession.ExpiresOn |  | ||||||
| 	s.Email = newSession.Email |  | ||||||
| 	return |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| type gitlabUserInfo struct { | type gitlabUserInfo struct { | ||||||
|  | @ -264,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	ss := &sessions.SessionState{ | ||||||
| 	return &sessions.SessionState{ |  | ||||||
| 		AccessToken:  token.AccessToken, | 		AccessToken:  token.AccessToken, | ||||||
| 		IDToken:      getIDToken(token), | 		IDToken:      getIDToken(token), | ||||||
| 		RefreshToken: token.RefreshToken, | 		RefreshToken: token.RefreshToken, | ||||||
| 		CreatedAt:    &created, | 	} | ||||||
| 		ExpiresOn:    &idToken.Expiry, | 
 | ||||||
| 	}, nil | 	ss.CreatedAtNow() | ||||||
|  | 	ss.SetExpiresOn(idToken.Expiry) | ||||||
|  | 
 | ||||||
|  | 	return ss, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSession checks that the session's IDToken is still valid
 | // ValidateSession checks that the session's IDToken is still valid
 | ||||||
|  |  | ||||||
|  | @ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	ss := &sessions.SessionState{ | ||||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) |  | ||||||
| 
 |  | ||||||
| 	return &sessions.SessionState{ |  | ||||||
| 		AccessToken:  jsonResponse.AccessToken, | 		AccessToken:  jsonResponse.AccessToken, | ||||||
| 		IDToken:      jsonResponse.IDToken, | 		IDToken:      jsonResponse.IDToken, | ||||||
| 		CreatedAt:    &created, |  | ||||||
| 		ExpiresOn:    &expires, |  | ||||||
| 		RefreshToken: jsonResponse.RefreshToken, | 		RefreshToken: jsonResponse.RefreshToken, | ||||||
| 		Email:        c.Email, | 		Email:        c.Email, | ||||||
| 		User:         c.Subject, | 		User:         c.Subject, | ||||||
| 	}, nil | 	} | ||||||
|  | 	ss.CreatedAtNow() | ||||||
|  | 	ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) | ||||||
|  | 
 | ||||||
|  | 	return ss, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // EnrichSession checks the listed Google Groups configured and adds any
 | // EnrichSession checks the listed Google Groups configured and adds any
 | ||||||
| // that the user is a member of to session.Groups.
 | // that the user is a member of to session.Groups.
 | ||||||
| func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error { | ||||||
| 	// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
 | 	// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
 | ||||||
| 	// reusing legacy `groupValidator`.
 | 	// reusing legacy `groupValidator`.
 | ||||||
| 	//
 | 	//
 | ||||||
|  | @ -266,14 +265,13 @@ func userInGroup(service *admin.Service, group string, email string) bool { | ||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new ID token if required
 | func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) | 	err := p.redeemRefreshToken(ctx, s) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return false, err | 		return false, err | ||||||
| 	} | 	} | ||||||
|  | @ -286,26 +284,20 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions | ||||||
| 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) | 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	origExpiration := s.ExpiresOn |  | ||||||
| 	expires := time.Now().Add(duration).Truncate(time.Second) |  | ||||||
| 	s.AccessToken = newToken |  | ||||||
| 	s.IDToken = newIDToken |  | ||||||
| 	s.ExpiresOn = &expires |  | ||||||
| 	logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) |  | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken string) (token string, idToken string, expires time.Duration, err error) { | func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||||
| 	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
 | 	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
 | ||||||
| 	clientSecret, err := p.GetClientSecret() | 	clientSecret, err := p.GetClientSecret() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	params := url.Values{} | 	params := url.Values{} | ||||||
| 	params.Add("client_id", p.ClientID) | 	params.Add("client_id", p.ClientID) | ||||||
| 	params.Add("client_secret", clientSecret) | 	params.Add("client_secret", clientSecret) | ||||||
| 	params.Add("refresh_token", refreshToken) | 	params.Add("refresh_token", s.RefreshToken) | ||||||
| 	params.Add("grant_type", "refresh_token") | 	params.Add("grant_type", "refresh_token") | ||||||
| 
 | 
 | ||||||
| 	var data struct { | 	var data struct { | ||||||
|  | @ -322,11 +314,14 @@ func (p *GoogleProvider) redeemRefreshToken(ctx context.Context, refreshToken st | ||||||
| 		Do(). | 		Do(). | ||||||
| 		UnmarshalInto(&data) | 		UnmarshalInto(&data) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", "", 0, err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	token = data.AccessToken | 	s.AccessToken = data.AccessToken | ||||||
| 	idToken = data.IDToken | 	s.IDToken = data.IDToken | ||||||
| 	expires = time.Duration(data.ExpiresIn) * time.Second | 
 | ||||||
| 	return | 	s.CreatedAtNow() | ||||||
|  | 	s.ExpiresIn(time.Duration(data.ExpiresIn) * time.Second) | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -93,7 +93,7 @@ func (p *LinkedInProvider) GetEmailAddress(ctx context.Context, s *sessions.Sess | ||||||
| 	return email, nil | 	return email, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // ValidateSessionState validates the AccessToken
 | // ValidateSession validates the AccessToken
 | ||||||
| func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | func (p *LinkedInProvider) ValidateSession(ctx context.Context, s *sessions.SessionState) bool { | ||||||
| 	return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) | 	return validateToken(ctx, p, s.AccessToken, makeLinkedInHeader(s.AccessToken)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||||
| func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) { | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, ErrMissingCode | 		return nil, ErrMissingCode | ||||||
| 	} | 	} | ||||||
|  | @ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	session := &sessions.SessionState{ | ||||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) |  | ||||||
| 
 |  | ||||||
| 	// Store the data that we found in the session state
 |  | ||||||
| 	return &sessions.SessionState{ |  | ||||||
| 		AccessToken: jsonResponse.AccessToken, | 		AccessToken: jsonResponse.AccessToken, | ||||||
| 		IDToken:     jsonResponse.IDToken, | 		IDToken:     jsonResponse.IDToken, | ||||||
| 		CreatedAt:   &created, |  | ||||||
| 		ExpiresOn:   &expires, |  | ||||||
| 		Email:       email, | 		Email:       email, | ||||||
| 	}, nil | 	} | ||||||
|  | 
 | ||||||
|  | 	session.CreatedAtNow() | ||||||
|  | 	session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) | ||||||
|  | 
 | ||||||
|  | 	return session, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // GetLoginURL overrides GetLoginURL to add login.gov parameters
 | // GetLoginURL overrides GetLoginURL to add login.gov parameters
 | ||||||
|  |  | ||||||
|  | @ -143,10 +143,9 @@ func (p *OIDCProvider) ValidateSession(ctx context.Context, s *sessions.SessionS | ||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||||
| // RefreshToken to fetch a new Access Token (and optional ID token) if required
 | func (p *OIDCProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | 	if s == nil || s.RefreshToken == "" { | ||||||
| 	if s == nil || (s.ExpiresOn != nil && s.ExpiresOn.After(time.Now())) || s.RefreshToken == "" { |  | ||||||
| 		return false, nil | 		return false, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -155,7 +154,6 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.S | ||||||
| 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	logger.Printf("refreshed session: %s", s) |  | ||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -227,7 +225,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) | ||||||
| 	ss.AccessToken = token | 	ss.AccessToken = token | ||||||
| 	ss.IDToken = token | 	ss.IDToken = token | ||||||
| 	ss.RefreshToken = "" | 	ss.RefreshToken = "" | ||||||
| 	ss.ExpiresOn = &idToken.Expiry | 
 | ||||||
|  | 	ss.CreatedAtNow() | ||||||
|  | 	ss.SetExpiresOn(idToken.Expiry) | ||||||
| 
 | 
 | ||||||
| 	return ss, nil | 	return ss, nil | ||||||
| } | } | ||||||
|  | @ -257,9 +257,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r | ||||||
| 	ss.RefreshToken = token.RefreshToken | 	ss.RefreshToken = token.RefreshToken | ||||||
| 	ss.IDToken = getIDToken(token) | 	ss.IDToken = getIDToken(token) | ||||||
| 
 | 
 | ||||||
| 	created := time.Now() | 	ss.CreatedAtNow() | ||||||
| 	ss.CreatedAt = &created | 	ss.SetExpiresOn(token.Expiry) | ||||||
| 	ss.ExpiresOn = &token.Expiry |  | ||||||
| 
 | 
 | ||||||
| 	return ss, nil | 	return ss, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -487,7 +487,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithoutIdToken(t *testing.T) { | ||||||
| 		User:         "11223344", | 		User:         "11223344", | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) | 	refreshed, err := provider.RefreshSession(context.Background(), existingSession) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, refreshed, true) | 	assert.Equal(t, refreshed, true) | ||||||
| 	assert.Equal(t, "janedoe@example.com", existingSession.Email) | 	assert.Equal(t, "janedoe@example.com", existingSession.Email) | ||||||
|  | @ -520,7 +520,7 @@ func TestOIDCProviderRefreshSessionIfNeededWithIdToken(t *testing.T) { | ||||||
| 		Email:        "changeit", | 		Email:        "changeit", | ||||||
| 		User:         "changeit", | 		User:         "changeit", | ||||||
| 	} | 	} | ||||||
| 	refreshed, err := provider.RefreshSessionIfNeeded(context.Background(), existingSession) | 	refreshed, err := provider.RefreshSession(context.Background(), existingSession) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, refreshed, true) | 	assert.Equal(t, refreshed, true) | ||||||
| 	assert.Equal(t, defaultIDToken.Email, existingSession.Email) | 	assert.Equal(t, defaultIDToken.Email, existingSession.Email) | ||||||
|  |  | ||||||
|  | @ -6,7 +6,6 @@ import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"time" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||||
|  | @ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|  | 	// TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration
 | ||||||
| 	if token := values.Get("access_token"); token != "" { | 	if token := values.Get("access_token"); token != "" { | ||||||
| 		created := time.Now() | 		ss := &sessions.SessionState{ | ||||||
| 		return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil | 			AccessToken: token, | ||||||
|  | 		} | ||||||
|  | 		ss.CreatedAtNow() | ||||||
|  | 		return ss, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return nil, fmt.Errorf("no access token found %s", result.Body()) | 	return nil, fmt.Errorf("no access token found %s", result.Body()) | ||||||
|  | @ -126,10 +129,9 @@ func (p *ProviderData) ValidateSession(ctx context.Context, s *sessions.SessionS | ||||||
| 	return validateToken(ctx, p, s.AccessToken, nil) | 	return validateToken(ctx, p, s.AccessToken, nil) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // RefreshSessionIfNeeded should refresh the user's session if required and
 | // RefreshSession refreshes the user's session
 | ||||||
| // do nothing if a refresh is not required
 | func (p *ProviderData) RefreshSession(_ context.Context, _ *sessions.SessionState) (bool, error) { | ||||||
| func (p *ProviderData) RefreshSessionIfNeeded(_ context.Context, _ *sessions.SessionState) (bool, error) { | 	return false, ErrNotImplemented | ||||||
| 	return false, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // CreateSessionFromToken converts Bearer IDTokens into sessions
 | // CreateSessionFromToken converts Bearer IDTokens into sessions
 | ||||||
|  |  | ||||||
|  | @ -14,12 +14,20 @@ import ( | ||||||
| func TestRefresh(t *testing.T) { | func TestRefresh(t *testing.T) { | ||||||
| 	p := &ProviderData{} | 	p := &ProviderData{} | ||||||
| 
 | 
 | ||||||
| 	expires := time.Now().Add(time.Duration(-11) * time.Minute) | 	now := time.Unix(1234567890, 10) | ||||||
| 	refreshed, err := p.RefreshSessionIfNeeded(context.Background(), &sessions.SessionState{ | 	expires := time.Unix(1234567890, 0) | ||||||
| 		ExpiresOn: &expires, | 
 | ||||||
| 	}) | 	ss := &sessions.SessionState{} | ||||||
| 	assert.Equal(t, false, refreshed) | 	ss.Clock.Set(now) | ||||||
| 	assert.Equal(t, nil, err) | 	ss.SetExpiresOn(expires) | ||||||
|  | 
 | ||||||
|  | 	refreshed, err := p.RefreshSession(context.Background(), ss) | ||||||
|  | 	assert.False(t, refreshed) | ||||||
|  | 	assert.Equal(t, ErrNotImplemented, err) | ||||||
|  | 
 | ||||||
|  | 	refreshed, err = p.RefreshSession(context.Background(), nil) | ||||||
|  | 	assert.False(t, refreshed) | ||||||
|  | 	assert.Equal(t, ErrNotImplemented, err) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestAcrValuesNotConfigured(t *testing.T) { | func TestAcrValuesNotConfigured(t *testing.T) { | ||||||
|  |  | ||||||
|  | @ -9,14 +9,14 @@ import ( | ||||||
| // Provider represents an upstream identity provider implementation
 | // Provider represents an upstream identity provider implementation
 | ||||||
| type Provider interface { | type Provider interface { | ||||||
| 	Data() *ProviderData | 	Data() *ProviderData | ||||||
|  | 	GetLoginURL(redirectURI, finalRedirect string, nonce string) string | ||||||
|  | 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) | ||||||
| 	// Deprecated: Migrate to EnrichSession
 | 	// Deprecated: Migrate to EnrichSession
 | ||||||
| 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | 	GetEmailAddress(ctx context.Context, s *sessions.SessionState) (string, error) | ||||||
| 	GetLoginURL(redirectURI, state, nonce string) string |  | ||||||
| 	Redeem(ctx context.Context, redirectURI, code string) (*sessions.SessionState, error) |  | ||||||
| 	EnrichSession(ctx context.Context, s *sessions.SessionState) error | 	EnrichSession(ctx context.Context, s *sessions.SessionState) error | ||||||
| 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) | 	Authorize(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
| 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | 	ValidateSession(ctx context.Context, s *sessions.SessionState) bool | ||||||
| 	RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) | 	RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) | ||||||
| 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) | 	CreateSessionFromToken(ctx context.Context, token string) (*sessions.SessionState, error) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue