Merge pull request #14 from pusher/oidc
OIDC ID Token, Authorization Headers, Refreshing and Verification
This commit is contained in:
		
						commit
						440d2f32bf
					
				|  | @ -2,6 +2,11 @@ | |||
| 
 | ||||
| ## Changes since v3.0.0 | ||||
| 
 | ||||
| - [#14](https://github.com/pusher/oauth2_proxy/pull/14) OIDC ID Token, Authorization Headers, Refreshing and Verification (@joelspeed) | ||||
|   - Implement `pass-authorization-header` and `set-authorization-header` flags | ||||
|   - Implement token refreshing in OIDC provider | ||||
|   - Split cookies larger than 4k limit into multiple cookies | ||||
|   - Implement token validation in OIDC provider | ||||
| - [#21](https://github.com/pusher/oauth2_proxy/pull/21) Docker Improvement (@yaegashi) | ||||
|   - Move Docker base image from debian to alpine | ||||
|   - Install ca-certificates in docker image | ||||
|  |  | |||
|  | @ -212,6 +212,7 @@ Usage of oauth2_proxy: | |||
|   -https-address string: <addr>:<port> to listen on for HTTPS clients (default ":443") | ||||
|   -login-url string: Authentication endpoint | ||||
|   -pass-access-token: pass OAuth access_token to upstream via X-Forwarded-Access-Token header | ||||
|   -pass-authorization-header: pass OIDC IDToken to upstream via Authorization Bearer header | ||||
|   -pass-basic-auth: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream (default true) | ||||
|   -pass-host-header: pass the request Host Header to upstream (default true) | ||||
|   -pass-user-headers: pass X-Forwarded-User and X-Forwarded-Email information to upstream (default true) | ||||
|  | @ -225,6 +226,7 @@ Usage of oauth2_proxy: | |||
|   -resource string: The resource that is protected (Azure AD only) | ||||
|   -scope string: OAuth scope specification | ||||
|   -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) | ||||
|   -set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode) | ||||
|   -signature-key string: GAP-Signature request signature key (algorithm:secretkey) | ||||
|   -skip-auth-preflight: will skip authentication for OPTIONS requests | ||||
|   -skip-auth-regex value: bypass authentication for requests path's that match (may be given multiple times) | ||||
|  |  | |||
							
								
								
									
										2
									
								
								main.go
								
								
								
								
							
							
						
						
									
										2
									
								
								main.go
								
								
								
								
							|  | @ -37,6 +37,8 @@ func main() { | |||
| 	flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") | ||||
| 	flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") | ||||
| 	flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") | ||||
| 	flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream") | ||||
| 	flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)") | ||||
| 	flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") | ||||
| 	flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") | ||||
| 	flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") | ||||
|  |  | |||
							
								
								
									
										127
									
								
								oauthproxy.go
								
								
								
								
							
							
						
						
									
										127
									
								
								oauthproxy.go
								
								
								
								
							|  | @ -26,6 +26,10 @@ const ( | |||
| 
 | ||||
| 	httpScheme  = "http" | ||||
| 	httpsScheme = "https" | ||||
| 
 | ||||
| 	// Cookies are limited to 4kb including the length of the cookie name,
 | ||||
| 	// the cookie name can be up to 256 bytes
 | ||||
| 	maxCookieLength = 3840 | ||||
| ) | ||||
| 
 | ||||
| // SignatureHeaders contains the headers to be signed by the hmac algorithm
 | ||||
|  | @ -76,6 +80,8 @@ type OAuthProxy struct { | |||
| 	PassUserHeaders     bool | ||||
| 	BasicAuthPassword   string | ||||
| 	PassAccessToken     bool | ||||
| 	SetAuthorization    bool | ||||
| 	PassAuthorization   bool | ||||
| 	CookieCipher        *cookie.Cipher | ||||
| 	skipAuthRegex       []string | ||||
| 	skipAuthPreflight   bool | ||||
|  | @ -183,7 +189,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | |||
| 	log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh) | ||||
| 
 | ||||
| 	var cipher *cookie.Cipher | ||||
| 	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { | ||||
| 	if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) { | ||||
| 		var err error | ||||
| 		cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) | ||||
| 		if err != nil { | ||||
|  | @ -222,6 +228,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { | |||
| 		PassUserHeaders:    opts.PassUserHeaders, | ||||
| 		BasicAuthPassword:  opts.BasicAuthPassword, | ||||
| 		PassAccessToken:    opts.PassAccessToken, | ||||
| 		SetAuthorization:   opts.SetAuthorization, | ||||
| 		PassAuthorization:  opts.PassAuthorization, | ||||
| 		SkipProviderButton: opts.SkipProviderButton, | ||||
| 		CookieCipher:       cipher, | ||||
| 		templates:          loadTemplates(opts.CustomTemplatesDir), | ||||
|  | @ -278,15 +286,100 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e | |||
| 
 | ||||
| // MakeSessionCookie creates an http.Cookie containing the authenticated user's
 | ||||
| // authentication details
 | ||||
| func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { | ||||
| func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { | ||||
| 	if value != "" { | ||||
| 		value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) | ||||
| 		if len(value) > 4096 { | ||||
| 			// Cookies cannot be larger than 4kb
 | ||||
| 			log.Printf("WARNING - Cookie Size: %d bytes", len(value)) | ||||
| 	} | ||||
| 	c := p.makeCookie(req, p.CookieName, value, expiration, now) | ||||
| 	if len(c.Value) > 4096-len(p.CookieName) { | ||||
| 		return splitCookie(c) | ||||
| 	} | ||||
| 	return []*http.Cookie{c} | ||||
| } | ||||
| 
 | ||||
| func copyCookie(c *http.Cookie) *http.Cookie { | ||||
| 	return &http.Cookie{ | ||||
| 		Name:       c.Name, | ||||
| 		Value:      c.Value, | ||||
| 		Path:       c.Path, | ||||
| 		Domain:     c.Domain, | ||||
| 		Expires:    c.Expires, | ||||
| 		RawExpires: c.RawExpires, | ||||
| 		MaxAge:     c.MaxAge, | ||||
| 		Secure:     c.Secure, | ||||
| 		HttpOnly:   c.HttpOnly, | ||||
| 		Raw:        c.Raw, | ||||
| 		Unparsed:   c.Unparsed, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // splitCookie reads the full cookie generated to store the session and splits
 | ||||
| // it into a slice of cookies which fit within the 4kb cookie limit indexing
 | ||||
| // the cookies from 0
 | ||||
| func splitCookie(c *http.Cookie) []*http.Cookie { | ||||
| 	if len(c.Value) < maxCookieLength { | ||||
| 		return []*http.Cookie{c} | ||||
| 	} | ||||
| 	cookies := []*http.Cookie{} | ||||
| 	valueBytes := []byte(c.Value) | ||||
| 	count := 0 | ||||
| 	for len(valueBytes) > 0 { | ||||
| 		new := copyCookie(c) | ||||
| 		new.Name = fmt.Sprintf("%s-%d", c.Name, count) | ||||
| 		count++ | ||||
| 		if len(valueBytes) < maxCookieLength { | ||||
| 			new.Value = string(valueBytes) | ||||
| 			valueBytes = []byte{} | ||||
| 		} else { | ||||
| 			newValue := valueBytes[:maxCookieLength] | ||||
| 			valueBytes = valueBytes[maxCookieLength:] | ||||
| 			new.Value = string(newValue) | ||||
| 		} | ||||
| 		cookies = append(cookies, new) | ||||
| 	} | ||||
| 	return cookies | ||||
| } | ||||
| 
 | ||||
| // joinCookies takes a slice of cookies from the request and reconstructs the
 | ||||
| // full session cookie
 | ||||
| func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { | ||||
| 	if len(cookies) == 0 { | ||||
| 		return nil, fmt.Errorf("list of cookies must be > 0") | ||||
| 	} | ||||
| 	if len(cookies) == 1 { | ||||
| 		return cookies[0], nil | ||||
| 	} | ||||
| 	c := copyCookie(cookies[0]) | ||||
| 	for i := 1; i < len(cookies); i++ { | ||||
| 		c.Value += cookies[i].Value | ||||
| 	} | ||||
| 	c.Name = strings.TrimRight(c.Name, "-0") | ||||
| 	return c, nil | ||||
| } | ||||
| 
 | ||||
| // loadCookie retreieves the sessions state cookie from the http request.
 | ||||
| // If a single cookie is present this will be returned, otherwise it attempts
 | ||||
| // to reconstruct a cookie split up by splitCookie
 | ||||
| func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { | ||||
| 	c, err := req.Cookie(cookieName) | ||||
| 	if err == nil { | ||||
| 		return c, nil | ||||
| 	} | ||||
| 	cookies := []*http.Cookie{} | ||||
| 	err = nil | ||||
| 	count := 0 | ||||
| 	for err == nil { | ||||
| 		var c *http.Cookie | ||||
| 		c, err = req.Cookie(fmt.Sprintf("%s-%d", cookieName, count)) | ||||
| 		if err == nil { | ||||
| 			cookies = append(cookies, c) | ||||
| 			count++ | ||||
| 		} | ||||
| 	} | ||||
| 	return p.makeCookie(req, p.CookieName, value, expiration, now) | ||||
| 	if len(cookies) == 0 { | ||||
| 		return nil, fmt.Errorf("Could not find cookie %s", cookieName) | ||||
| 	} | ||||
| 	return joinCookies(cookies) | ||||
| } | ||||
| 
 | ||||
| // MakeCSRFCookie creates a cookie for CSRF
 | ||||
|  | @ -330,12 +423,14 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va | |||
| // ClearSessionCookie creates a cookie to unset the user's authentication cookie
 | ||||
| // stored in the user's session
 | ||||
| func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { | ||||
| 	clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) | ||||
| 	http.SetCookie(rw, clr) | ||||
| 	cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) | ||||
| 	for _, clr := range cookies { | ||||
| 		http.SetCookie(rw, clr) | ||||
| 	} | ||||
| 
 | ||||
| 	// ugly hack because default domain changed
 | ||||
| 	if p.CookieDomain == "" { | ||||
| 		clr2 := *clr | ||||
| 	if p.CookieDomain == "" && len(cookies) > 0 { | ||||
| 		clr2 := *cookies[0] | ||||
| 		clr2.Domain = req.Host | ||||
| 		http.SetCookie(rw, &clr2) | ||||
| 	} | ||||
|  | @ -343,13 +438,15 @@ func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Reques | |||
| 
 | ||||
| // SetSessionCookie adds the user's session cookie to the response
 | ||||
| func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { | ||||
| 	http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) | ||||
| 	for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) { | ||||
| 		http.SetCookie(rw, c) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| // LoadCookiedSession reads the user's authentication details from the request
 | ||||
| func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { | ||||
| 	var age time.Duration | ||||
| 	c, err := req.Cookie(p.CookieName) | ||||
| 	c, err := loadCookie(req, p.CookieName) | ||||
| 	if err != nil { | ||||
| 		// always http.ErrNoCookie
 | ||||
| 		return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) | ||||
|  | @ -750,6 +847,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int | |||
| 	if p.PassAccessToken && session.AccessToken != "" { | ||||
| 		req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} | ||||
| 	} | ||||
| 	if p.PassAuthorization && session.IDToken != "" { | ||||
| 		req.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", session.IDToken)} | ||||
| 	} | ||||
| 	if p.SetAuthorization && session.IDToken != "" { | ||||
| 		rw.Header().Set("Authorization", fmt.Sprintf("Bearer %s", session.IDToken)) | ||||
| 	} | ||||
| 	if session.Email == "" { | ||||
| 		rw.Header().Set("GAP-Auth", session.User) | ||||
| 	} else { | ||||
|  |  | |||
|  | @ -502,7 +502,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { | |||
| 	}) | ||||
| } | ||||
| 
 | ||||
| func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie { | ||||
| func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { | ||||
| 	return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) | ||||
| } | ||||
| 
 | ||||
|  | @ -511,7 +511,9 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time | |||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref)) | ||||
| 	for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { | ||||
| 		p.req.AddCookie(c) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
|  | @ -800,8 +802,9 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { | |||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) | ||||
| 	req.AddCookie(cookie) | ||||
| 	for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) { | ||||
| 		req.AddCookie(c) | ||||
| 	} | ||||
| 	// This is used by the upstream to validate the signature.
 | ||||
| 	st.authenticator.auth = hmacauth.NewHmacAuth( | ||||
| 		crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) | ||||
|  |  | |||
|  | @ -61,6 +61,8 @@ type Options struct { | |||
| 	PassUserHeaders       bool     `flag:"pass-user-headers" cfg:"pass_user_headers"` | ||||
| 	SSLInsecureSkipVerify bool     `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` | ||||
| 	SetXAuthRequest       bool     `flag:"set-xauthrequest" cfg:"set_xauthrequest"` | ||||
| 	SetAuthorization      bool     `flag:"set-authorization-header" cfg:"set_authorization_header"` | ||||
| 	PassAuthorization     bool     `flag:"pass-authorization-header" cfg:"pass_authorization_header"` | ||||
| 	SkipAuthPreflight     bool     `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` | ||||
| 
 | ||||
| 	// These options allow for other providers besides Google, with
 | ||||
|  | @ -113,6 +115,8 @@ func NewOptions() *Options { | |||
| 		PassUserHeaders:      true, | ||||
| 		PassAccessToken:      false, | ||||
| 		PassHostHeader:       true, | ||||
| 		SetAuthorization:     false, | ||||
| 		PassAuthorization:    false, | ||||
| 		ApprovalPrompt:       "force", | ||||
| 		RequestLogging:       true, | ||||
| 		RequestLoggingFormat: defaultRequestLoggingFormat, | ||||
|  |  | |||
|  | @ -145,6 +145,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err | |||
| 	} | ||||
| 	s = &SessionState{ | ||||
| 		AccessToken:  jsonResponse.AccessToken, | ||||
| 		IDToken:      jsonResponse.IDToken, | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), | ||||
| 		RefreshToken: jsonResponse.RefreshToken, | ||||
| 		Email:        email, | ||||
|  |  | |||
|  | @ -38,7 +38,61 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er | |||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("token exchange: %v", err) | ||||
| 	} | ||||
| 	s, err = p.createSessionState(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("unable to update session: %v", err) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
| // RefreshToken to fetch a new ID token if required
 | ||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||
| 		return false, nil | ||||
| 	} | ||||
| 
 | ||||
| 	origExpiration := s.ExpiresOn | ||||
| 
 | ||||
| 	err := p.redeemRefreshToken(s) | ||||
| 	if err != nil { | ||||
| 		return false, fmt.Errorf("unable to redeem refresh token: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration) | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { | ||||
| 	c := oauth2.Config{ | ||||
| 		ClientID:     p.ClientID, | ||||
| 		ClientSecret: p.ClientSecret, | ||||
| 		Endpoint: oauth2.Endpoint{ | ||||
| 			TokenURL: p.RedeemURL.String(), | ||||
| 		}, | ||||
| 	} | ||||
| 	ctx := context.Background() | ||||
| 	t := &oauth2.Token{ | ||||
| 		RefreshToken: s.RefreshToken, | ||||
| 		Expiry:       time.Now().Add(-time.Hour), | ||||
| 	} | ||||
| 	token, err := c.TokenSource(ctx, t).Token() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("failed to get token: %v", err) | ||||
| 	} | ||||
| 	newSession, err := p.createSessionState(ctx, token) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("unable to update session: %v", err) | ||||
| 	} | ||||
| 	s.AccessToken = newSession.AccessToken | ||||
| 	s.IDToken = newSession.IDToken | ||||
| 	s.RefreshToken = newSession.RefreshToken | ||||
| 	s.ExpiresOn = newSession.ExpiresOn | ||||
| 	s.Email = newSession.Email | ||||
| 	return | ||||
| } | ||||
| 
 | ||||
| func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { | ||||
| 	rawIDToken, ok := token.Extra("id_token").(string) | ||||
| 	if !ok { | ||||
| 		return nil, fmt.Errorf("token response did not contain an id_token") | ||||
|  | @ -66,28 +120,22 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er | |||
| 		return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) | ||||
| 	} | ||||
| 
 | ||||
| 	s = &SessionState{ | ||||
| 	return &SessionState{ | ||||
| 		AccessToken:  token.AccessToken, | ||||
| 		IDToken:      rawIDToken, | ||||
| 		RefreshToken: token.RefreshToken, | ||||
| 		ExpiresOn:    token.Expiry, | ||||
| 		Email:        claims.Email, | ||||
| 	} | ||||
| 
 | ||||
| 	return | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| // RefreshSessionIfNeeded checks if the session has expired and uses the
 | ||||
| // RefreshToken to fetch a new ID token if required
 | ||||
| //
 | ||||
| // WARNGING: This implementation is broken and does not check with the upstream
 | ||||
| // OIDC provider before refreshing the session
 | ||||
| func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { | ||||
| 	if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { | ||||
| 		return false, nil | ||||
| // ValidateSessionState checks that the session's IDToken is still valid
 | ||||
| func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { | ||||
| 	ctx := context.Background() | ||||
| 	_, err := p.Verifier.Verify(ctx, s.IDToken) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
| 
 | ||||
| 	origExpiration := s.ExpiresOn | ||||
| 	s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second) | ||||
| 	fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration) | ||||
| 	return false, nil | ||||
| 	return true | ||||
| } | ||||
|  |  | |||
|  | @ -12,6 +12,7 @@ import ( | |||
| // SessionState is used to store information about the currently authenticated user session
 | ||||
| type SessionState struct { | ||||
| 	AccessToken  string | ||||
| 	IDToken      string | ||||
| 	ExpiresOn    time.Time | ||||
| 	RefreshToken string | ||||
| 	Email        string | ||||
|  | @ -32,6 +33,9 @@ func (s *SessionState) String() string { | |||
| 	if s.AccessToken != "" { | ||||
| 		o += " token:true" | ||||
| 	} | ||||
| 	if s.IDToken != "" { | ||||
| 		o += " id_token:true" | ||||
| 	} | ||||
| 	if !s.ExpiresOn.IsZero() { | ||||
| 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | ||||
| 	} | ||||
|  | @ -65,13 +69,19 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { | |||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	i := s.IDToken | ||||
| 	if i != "" { | ||||
| 		if i, err = c.Encrypt(i); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	r := s.RefreshToken | ||||
| 	if r != "" { | ||||
| 		if r, err = c.Encrypt(r); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
| 	} | ||||
| 	return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil | ||||
| 	return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil | ||||
| } | ||||
| 
 | ||||
| func decodeSessionStatePlain(v string) (s *SessionState, err error) { | ||||
|  | @ -96,8 +106,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) | |||
| 	} | ||||
| 
 | ||||
| 	chunks := strings.Split(v, "|") | ||||
| 	if len(chunks) != 4 { | ||||
| 		err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) | ||||
| 	if len(chunks) != 5 { | ||||
| 		err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
|  | @ -112,11 +122,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) | |||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	ts, _ := strconv.Atoi(chunks[2]) | ||||
| 	if chunks[2] != "" { | ||||
| 		if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	ts, _ := strconv.Atoi(chunks[3]) | ||||
| 	sessionState.ExpiresOn = time.Unix(int64(ts), 0) | ||||
| 
 | ||||
| 	if chunks[3] != "" { | ||||
| 		if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { | ||||
| 	if chunks[4] != "" { | ||||
| 		if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  |  | |||
|  | @ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	s := &SessionState{ | ||||
| 		Email:        "user@domain.com", | ||||
| 		AccessToken:  "token1234", | ||||
| 		IDToken:      "rawtoken1234", | ||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||
| 		RefreshToken: "refresh4321", | ||||
| 	} | ||||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, 3, strings.Count(encoded, "|")) | ||||
| 	assert.Equal(t, 4, strings.Count(encoded, "|")) | ||||
| 
 | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
|  | @ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	assert.Equal(t, "user", ss.User) | ||||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.Equal(t, s.IDToken, ss.IDToken) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||
| 
 | ||||
|  | @ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) { | |||
| 	assert.Equal(t, s.Email, ss.Email) | ||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||
| 	assert.NotEqual(t, s.IDToken, ss.IDToken) | ||||
| 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | ||||
| } | ||||
| 
 | ||||
|  | @ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | |||
| 	} | ||||
| 	encoded, err := s.EncodeSessionState(c) | ||||
| 	assert.Equal(t, nil, err) | ||||
| 	assert.Equal(t, 3, strings.Count(encoded, "|")) | ||||
| 	assert.Equal(t, 4, strings.Count(encoded, "|")) | ||||
| 
 | ||||
| 	ss, err := DecodeSessionState(encoded, c) | ||||
| 	t.Logf("%#v", ss) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue