Update sessions state
This commit is contained in:
		
							parent
							
								
									68d4164897
								
							
						
					
					
						commit
						6aa35a9ecf
					
				|  | @ -33,6 +33,9 @@ func (s *SessionState) String() string { | ||||||
| 	if s.AccessToken != "" { | 	if s.AccessToken != "" { | ||||||
| 		o += " token:true" | 		o += " token:true" | ||||||
| 	} | 	} | ||||||
|  | 	if s.IDToken != "" { | ||||||
|  | 		o += " id_token:true" | ||||||
|  | 	} | ||||||
| 	if !s.ExpiresOn.IsZero() { | 	if !s.ExpiresOn.IsZero() { | ||||||
| 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | 		o += fmt.Sprintf(" expires:%s", s.ExpiresOn) | ||||||
| 	} | 	} | ||||||
|  | @ -66,13 +69,19 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	i := s.IDToken | ||||||
|  | 	if i != "" { | ||||||
|  | 		if i, err = c.Encrypt(i); err != nil { | ||||||
|  | 			return "", err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
| 	r := s.RefreshToken | 	r := s.RefreshToken | ||||||
| 	if r != "" { | 	if r != "" { | ||||||
| 		if r, err = c.Encrypt(r); err != nil { | 		if r, err = c.Encrypt(r); err != nil { | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil | 	return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func decodeSessionStatePlain(v string) (s *SessionState, err error) { | func decodeSessionStatePlain(v string) (s *SessionState, err error) { | ||||||
|  | @ -97,8 +106,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	chunks := strings.Split(v, "|") | 	chunks := strings.Split(v, "|") | ||||||
| 	if len(chunks) != 4 { | 	if len(chunks) != 5 { | ||||||
| 		err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) | 		err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | @ -113,11 +122,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ts, _ := strconv.Atoi(chunks[2]) | 	if chunks[2] != "" { | ||||||
|  | 		if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	ts, _ := strconv.Atoi(chunks[3]) | ||||||
| 	sessionState.ExpiresOn = time.Unix(int64(ts), 0) | 	sessionState.ExpiresOn = time.Unix(int64(ts), 0) | ||||||
| 
 | 
 | ||||||
| 	if chunks[3] != "" { | 	if chunks[4] != "" { | ||||||
| 		if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { | 		if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | @ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	s := &SessionState{ | 	s := &SessionState{ | ||||||
| 		Email:        "user@domain.com", | 		Email:        "user@domain.com", | ||||||
| 		AccessToken:  "token1234", | 		AccessToken:  "token1234", | ||||||
|  | 		IDToken:      "rawtoken1234", | ||||||
| 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | 		ExpiresOn:    time.Now().Add(time.Duration(1) * time.Hour), | ||||||
| 		RefreshToken: "refresh4321", | 		RefreshToken: "refresh4321", | ||||||
| 	} | 	} | ||||||
| 	encoded, err := s.EncodeSessionState(c) | 	encoded, err := s.EncodeSessionState(c) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, 3, strings.Count(encoded, "|")) | 	assert.Equal(t, 4, strings.Count(encoded, "|")) | ||||||
| 
 | 
 | ||||||
| 	ss, err := DecodeSessionState(encoded, c) | 	ss, err := DecodeSessionState(encoded, c) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
|  | @ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, "user", ss.User) | 	assert.Equal(t, "user", ss.User) | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
| 	assert.Equal(t, s.AccessToken, ss.AccessToken) | 	assert.Equal(t, s.AccessToken, ss.AccessToken) | ||||||
|  | 	assert.Equal(t, s.IDToken, ss.IDToken) | ||||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||||
| 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | 	assert.Equal(t, s.RefreshToken, ss.RefreshToken) | ||||||
| 
 | 
 | ||||||
|  | @ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) { | ||||||
| 	assert.Equal(t, s.Email, ss.Email) | 	assert.Equal(t, s.Email, ss.Email) | ||||||
| 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | 	assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) | ||||||
| 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | 	assert.NotEqual(t, s.AccessToken, ss.AccessToken) | ||||||
|  | 	assert.NotEqual(t, s.IDToken, ss.IDToken) | ||||||
| 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | 	assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { | ||||||
| 	} | 	} | ||||||
| 	encoded, err := s.EncodeSessionState(c) | 	encoded, err := s.EncodeSessionState(c) | ||||||
| 	assert.Equal(t, nil, err) | 	assert.Equal(t, nil, err) | ||||||
| 	assert.Equal(t, 3, strings.Count(encoded, "|")) | 	assert.Equal(t, 4, strings.Count(encoded, "|")) | ||||||
| 
 | 
 | ||||||
| 	ss, err := DecodeSessionState(encoded, c) | 	ss, err := DecodeSessionState(encoded, c) | ||||||
| 	t.Logf("%#v", ss) | 	t.Logf("%#v", ss) | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue