Manage session time fields centrally
This commit is contained in:
		
							parent
							
								
									7e80e5596b
								
							
						
					
					
						commit
						7fa6d2d024
					
				|  | @ -786,6 +786,15 @@ func (p *OAuthProxy) redeemCode(req *http.Request) (*sessionsapi.SessionState, e | |||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Force setting these in case the Provider didn't
 | ||||
| 	if s.CreatedAt == nil { | ||||
| 		s.CreatedAtNow() | ||||
| 	} | ||||
| 	if s.ExpiresOn == nil { | ||||
| 		s.ExpiresIn(p.CookieOptions.Expire) | ||||
| 	} | ||||
| 
 | ||||
| 	return s, nil | ||||
| } | ||||
| 
 | ||||
|  | @ -861,9 +870,9 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { | |||
| 
 | ||||
| // See https://developers.google.com/web/fundamentals/performance/optimizing-content-efficiency/http-caching?hl=en
 | ||||
| var noCacheHeaders = map[string]string{ | ||||
| 	"Expires":         time.Unix(0, 0).Format(time.RFC1123), | ||||
| 	"Cache-Control":   "no-cache, no-store, must-revalidate, max-age=0", | ||||
| 	"X-Accel-Expires": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 | ||||
| 	"Expires":        time.Unix(0, 0).Format(time.RFC1123), | ||||
| 	"Cache-Control":  "no-cache, no-store, must-revalidate, max-age=0", | ||||
| 	"X-Accel-Expire": "0", // https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/
 | ||||
| } | ||||
| 
 | ||||
| // prepareNoCache prepares headers for preventing browser caching.
 | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ import ( | |||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" | ||||
| 	"github.com/pierrec/lz4" | ||||
| 	"github.com/vmihailenco/msgpack/v4" | ||||
|  | @ -32,7 +33,8 @@ type SessionState struct { | |||
| 	Groups            []string `msgpack:"g,omitempty"` | ||||
| 	PreferredUsername string   `msgpack:"pu,omitempty"` | ||||
| 
 | ||||
| 	Lock Lock `msgpack:"-"` | ||||
| 	Clock clock.Clock `msgpack:"-"` | ||||
| 	Lock  Lock        `msgpack:"-"` | ||||
| } | ||||
| 
 | ||||
| func (s *SessionState) ObtainLock(ctx context.Context, expiration time.Duration) error { | ||||
|  | @ -63,9 +65,30 @@ func (s *SessionState) PeekLock(ctx context.Context) (bool, error) { | |||
| 	return s.Lock.Peek(ctx) | ||||
| } | ||||
| 
 | ||||
| // CreatedAtNow sets a SessionState's CreatedAt to now
 | ||||
| func (s *SessionState) CreatedAtNow() { | ||||
| 	now := s.Clock.Now() | ||||
| 	s.CreatedAt = &now | ||||
| } | ||||
| 
 | ||||
| // SetExpiresOn sets an expiration
 | ||||
| func (s *SessionState) SetExpiresOn(exp time.Time) { | ||||
| 	s.ExpiresOn = &exp | ||||
| } | ||||
| 
 | ||||
| // ExpiresIn sets an expiration a certain duration from CreatedAt.
 | ||||
| // CreatedAt will be set to time.Now if it is unset.
 | ||||
| func (s *SessionState) ExpiresIn(d time.Duration) { | ||||
| 	if s.CreatedAt == nil { | ||||
| 		s.CreatedAtNow() | ||||
| 	} | ||||
| 	exp := s.CreatedAt.Add(d) | ||||
| 	s.ExpiresOn = &exp | ||||
| } | ||||
| 
 | ||||
| // IsExpired checks whether the session has expired
 | ||||
| func (s *SessionState) IsExpired() bool { | ||||
| 	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { | ||||
| 	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(s.Clock.Now()) { | ||||
| 		return true | ||||
| 	} | ||||
| 	return false | ||||
|  | @ -74,7 +97,7 @@ func (s *SessionState) IsExpired() bool { | |||
| // Age returns the age of a session
 | ||||
| func (s *SessionState) Age() time.Duration { | ||||
| 	if s.CreatedAt != nil && !s.CreatedAt.IsZero() { | ||||
| 		return time.Now().Truncate(time.Second).Sub(*s.CreatedAt) | ||||
| 		return s.Clock.Now().Truncate(time.Second).Sub(*s.CreatedAt) | ||||
| 	} | ||||
| 	return 0 | ||||
| } | ||||
|  |  | |||
|  | @ -142,8 +142,7 @@ func (s *storedSessionLoader) refreshSession(rw http.ResponseWriter, req *http.R | |||
| 	} | ||||
| 
 | ||||
| 	// If we refreshed, update the `CreatedAt` time to reset the refresh timer
 | ||||
| 	// TODO: Implement
 | ||||
| 	// session.CreatedAtNow()
 | ||||
| 	session.CreatedAtNow() | ||||
| 
 | ||||
| 	// Because the session was refreshed, make sure to save it
 | ||||
| 	err = s.store.Save(rw, req, session) | ||||
|  |  | |||
|  | @ -36,8 +36,7 @@ type SessionStore struct { | |||
| // within Cookies set on the HTTP response writer
 | ||||
| func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { | ||||
| 	if ss.CreatedAt == nil || ss.CreatedAt.IsZero() { | ||||
| 		now := time.Now() | ||||
| 		ss.CreatedAt = &now | ||||
| 		ss.CreatedAtNow() | ||||
| 	} | ||||
| 	value, err := s.cookieForSession(ss) | ||||
| 	if err != nil { | ||||
|  |  | |||
|  | @ -30,8 +30,7 @@ func NewManager(store Store, cookieOpts *options.Cookie) *Manager { | |||
| // from the persistent data store.
 | ||||
| func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { | ||||
| 	if s.CreatedAt == nil || s.CreatedAt.IsZero() { | ||||
| 		now := time.Now() | ||||
| 		s.CreatedAt = &now | ||||
| 		s.CreatedAtNow() | ||||
| 	} | ||||
| 
 | ||||
| 	tckt, err := decodeTicketFromRequest(req, m.Options) | ||||
|  |  | |||
|  | @ -142,16 +142,13 @@ func (p *AzureProvider) Redeem(ctx context.Context, redirectURL, code string) (* | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	expires := time.Unix(jsonResponse.ExpiresOn, 0) | ||||
| 
 | ||||
| 	session := &sessions.SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		CreatedAt:    &created, | ||||
| 		ExpiresOn:    &expires, | ||||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 	} | ||||
| 	session.CreatedAtNow() | ||||
| 	session.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | ||||
| 
 | ||||
| 	email, err := p.verifyTokenAndExtractEmail(ctx, session.IDToken) | ||||
| 
 | ||||
|  | @ -239,10 +236,9 @@ func (p *AzureProvider) verifyTokenAndExtractEmail(ctx context.Context, token st | |||
| 	return email, nil | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
| // RefreshToken to fetch a new ID token if required
 | ||||
| func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||
| // RefreshSession uses the RefreshToken to fetch new Access and ID Tokens
 | ||||
| func (p *AzureProvider) RefreshSession(ctx context.Context, s *sessions.SessionState) (bool, error) { | ||||
| 	if s == nil || s.RefreshToken == "" { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
|  | @ -257,7 +253,7 @@ func (p *AzureProvider) RefreshSessionIfNeeded(ctx context.Context, s *sessions. | |||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) (err error) { | ||||
| func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.SessionState) error { | ||||
| 	params := url.Values{} | ||||
| 	params.Add("client_id", p.ClientID) | ||||
| 	params.Add("client_secret", p.ClientSecret) | ||||
|  | @ -271,25 +267,23 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess | |||
| 		IDToken      string `json:"id_token"` | ||||
| 	} | ||||
| 
 | ||||
| 	err = requests.New(p.RedeemURL.String()). | ||||
| 	err := requests.New(p.RedeemURL.String()). | ||||
| 		WithContext(ctx). | ||||
| 		WithMethod("POST"). | ||||
| 		WithBody(bytes.NewBufferString(params.Encode())). | ||||
| 		SetHeader("Content-Type", "application/x-www-form-urlencoded"). | ||||
| 		Do(). | ||||
| 		UnmarshalInto(&jsonResponse) | ||||
| 
 | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	now := time.Now() | ||||
| 	expires := time.Unix(jsonResponse.ExpiresOn, 0) | ||||
| 	s.AccessToken = jsonResponse.AccessToken | ||||
| 	s.IDToken = jsonResponse.IDToken | ||||
| 	s.RefreshToken = jsonResponse.RefreshToken | ||||
| 	s.CreatedAt = &now | ||||
| 	s.ExpiresOn = &expires | ||||
| 
 | ||||
| 	s.CreatedAtNow() | ||||
| 	s.SetExpiresOn(time.Unix(jsonResponse.ExpiresOn, 0)) | ||||
| 
 | ||||
| 	email, err := p.verifyTokenAndExtractEmail(ctx, s.IDToken) | ||||
| 
 | ||||
|  | @ -312,7 +306,7 @@ func (p *AzureProvider) redeemRefreshToken(ctx context.Context, s *sessions.Sess | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func makeAzureHeader(accessToken string) http.Header { | ||||
|  |  | |||
|  | @ -259,14 +259,16 @@ func (p *GitLabProvider) createSession(ctx context.Context, token *oauth2.Token) | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	return &sessions.SessionState{ | ||||
| 	ss := &sessions.SessionState{ | ||||
| 		AccessToken:  token.AccessToken, | ||||
| 		IDToken:      getIDToken(token), | ||||
| 		RefreshToken: token.RefreshToken, | ||||
| 		CreatedAt:    &created, | ||||
| 		ExpiresOn:    &idToken.Expiry, | ||||
| 	}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	ss.CreatedAtNow() | ||||
| 	ss.SetExpiresOn(idToken.Expiry) | ||||
| 
 | ||||
| 	return ss, nil | ||||
| } | ||||
| 
 | ||||
| // ValidateSession checks that the session's IDToken is still valid
 | ||||
|  |  | |||
|  | @ -163,23 +163,22 @@ func (p *GoogleProvider) Redeem(ctx context.Context, redirectURL, code string) ( | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | ||||
| 
 | ||||
| 	return &sessions.SessionState{ | ||||
| 	ss := &sessions.SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		CreatedAt:    &created, | ||||
| 		ExpiresOn:    &expires, | ||||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 		Email:        c.Email, | ||||
| 		User:         c.Subject, | ||||
| 	}, nil | ||||
| 	} | ||||
| 	ss.CreatedAtNow() | ||||
| 	ss.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) | ||||
| 
 | ||||
| 	return ss, nil | ||||
| } | ||||
| 
 | ||||
| // EnrichSession checks the listed Google Groups configured and adds any
 | ||||
| // that the user is a member of to session.Groups.
 | ||||
| func (p *GoogleProvider) EnrichSession(ctx context.Context, s *sessions.SessionState) error { | ||||
| func (p *GoogleProvider) EnrichSession(_ context.Context, s *sessions.SessionState) error { | ||||
| 	// TODO (@NickMeves) - Move to pure EnrichSession logic and stop
 | ||||
| 	// reusing legacy `groupValidator`.
 | ||||
| 	//
 | ||||
|  | @ -272,7 +271,7 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session | |||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	newToken, newIDToken, duration, err := p.redeemRefreshToken(ctx, s.RefreshToken) | ||||
| 	newToken, newIDToken, ttl, err := p.redeemRefreshToken(ctx, s.RefreshToken) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | @ -285,12 +284,12 @@ func (p *GoogleProvider) RefreshSession(ctx context.Context, s *sessions.Session | |||
| 		return false, fmt.Errorf("%s is no longer in the group(s)", s.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	origExpiration := s.ExpiresOn | ||||
| 	expires := time.Now().Add(duration).Truncate(time.Second) | ||||
| 	s.AccessToken = newToken | ||||
| 	s.IDToken = newIDToken | ||||
| 	s.ExpiresOn = &expires | ||||
| 	logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) | ||||
| 
 | ||||
| 	s.CreatedAtNow() | ||||
| 	s.ExpiresIn(ttl) | ||||
| 
 | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -159,7 +159,7 @@ func emailFromUserInfo(ctx context.Context, accessToken string, userInfoEndpoint | |||
| } | ||||
| 
 | ||||
| // Redeem exchanges the OAuth2 authentication token for an ID token
 | ||||
| func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) (*sessions.SessionState, error) { | ||||
| func (p *LoginGovProvider) Redeem(ctx context.Context, _, code string) (*sessions.SessionState, error) { | ||||
| 	if code == "" { | ||||
| 		return nil, ErrMissingCode | ||||
| 	} | ||||
|  | @ -214,17 +214,16 @@ func (p *LoginGovProvider) Redeem(ctx context.Context, redirectURL, code string) | |||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	expires := time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second) | ||||
| 
 | ||||
| 	// Store the data that we found in the session state
 | ||||
| 	return &sessions.SessionState{ | ||||
| 	session := &sessions.SessionState{ | ||||
| 		AccessToken: jsonResponse.AccessToken, | ||||
| 		IDToken:     jsonResponse.IDToken, | ||||
| 		CreatedAt:   &created, | ||||
| 		ExpiresOn:   &expires, | ||||
| 		Email:       email, | ||||
| 	}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	session.CreatedAtNow() | ||||
| 	session.ExpiresIn(time.Duration(jsonResponse.ExpiresIn) * time.Second) | ||||
| 
 | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| // GetLoginURL overrides GetLoginURL to add login.gov parameters
 | ||||
|  |  | |||
|  | @ -226,7 +226,9 @@ func (p *OIDCProvider) CreateSessionFromToken(ctx context.Context, token string) | |||
| 	ss.AccessToken = token | ||||
| 	ss.IDToken = token | ||||
| 	ss.RefreshToken = "" | ||||
| 	ss.ExpiresOn = &idToken.Expiry | ||||
| 
 | ||||
| 	ss.CreatedAtNow() | ||||
| 	ss.SetExpiresOn(idToken.Expiry) | ||||
| 
 | ||||
| 	return ss, nil | ||||
| } | ||||
|  | @ -256,9 +258,8 @@ func (p *OIDCProvider) createSession(ctx context.Context, token *oauth2.Token, r | |||
| 	ss.RefreshToken = token.RefreshToken | ||||
| 	ss.IDToken = getIDToken(token) | ||||
| 
 | ||||
| 	created := time.Now() | ||||
| 	ss.CreatedAt = &created | ||||
| 	ss.ExpiresOn = &token.Expiry | ||||
| 	ss.CreatedAtNow() | ||||
| 	ss.SetExpiresOn(token.Expiry) | ||||
| 
 | ||||
| 	return ss, nil | ||||
| } | ||||
|  |  | |||
|  | @ -6,7 +6,6 @@ import ( | |||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" | ||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" | ||||
|  | @ -85,9 +84,13 @@ func (p *ProviderData) Redeem(ctx context.Context, redirectURL, code string) (*s | |||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	// TODO (@NickMeves): Uses OAuth `expires_in` to set an expiration
 | ||||
| 	if token := values.Get("access_token"); token != "" { | ||||
| 		created := time.Now() | ||||
| 		return &sessions.SessionState{AccessToken: token, CreatedAt: &created}, nil | ||||
| 		ss := &sessions.SessionState{ | ||||
| 			AccessToken: token, | ||||
| 		} | ||||
| 		ss.CreatedAtNow() | ||||
| 		return ss, nil | ||||
| 	} | ||||
| 
 | ||||
| 	return nil, fmt.Errorf("no access token found %s", result.Body()) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue