244 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			244 lines
		
	
	
		
			5.9 KiB
		
	
	
	
		
			Go
		
	
	
	
package sessions
 | 
						|
 | 
						|
import (
 | 
						|
	"encoding/json"
 | 
						|
	"fmt"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/pusher/oauth2_proxy/cookie"
 | 
						|
)
 | 
						|
 | 
						|
// SessionState holds what is known about the currently authenticated user
// session: provider-issued tokens, identity, and creation/expiry timestamps.
type SessionState struct {
	// Provider credentials; each is omitted from JSON output when empty.
	AccessToken string `json:",omitempty"`
	IDToken     string `json:",omitempty"`

	// Timestamps are excluded from direct JSON encoding (json:"-");
	// SessionStateJSON carries them as optional pointer fields instead.
	CreatedAt, ExpiresOn time.Time `json:"-"`

	RefreshToken string `json:",omitempty"`
	Email        string `json:",omitempty"`
	User         string `json:",omitempty"`
}
 | 
						|
 | 
						|
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
 | 
						|
type SessionStateJSON struct {
 | 
						|
	*SessionState
 | 
						|
	CreatedAt *time.Time `json:",omitempty"`
 | 
						|
	ExpiresOn *time.Time `json:",omitempty"`
 | 
						|
}
 | 
						|
 | 
						|
// IsExpired checks whether the session has expired
 | 
						|
func (s *SessionState) IsExpired() bool {
 | 
						|
	if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
 | 
						|
		return true
 | 
						|
	}
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
// Age returns the age of a session
 | 
						|
func (s *SessionState) Age() time.Duration {
 | 
						|
	if !s.CreatedAt.IsZero() {
 | 
						|
		return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
 | 
						|
	}
 | 
						|
	return 0
 | 
						|
}
 | 
						|
 | 
						|
// String constructs a summary of the session state
 | 
						|
func (s *SessionState) String() string {
 | 
						|
	o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
 | 
						|
	if s.AccessToken != "" {
 | 
						|
		o += " token:true"
 | 
						|
	}
 | 
						|
	if s.IDToken != "" {
 | 
						|
		o += " id_token:true"
 | 
						|
	}
 | 
						|
	if !s.CreatedAt.IsZero() {
 | 
						|
		o += fmt.Sprintf(" created:%s", s.CreatedAt)
 | 
						|
	}
 | 
						|
	if !s.ExpiresOn.IsZero() {
 | 
						|
		o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
 | 
						|
	}
 | 
						|
	if s.RefreshToken != "" {
 | 
						|
		o += " refresh_token:true"
 | 
						|
	}
 | 
						|
	return o + "}"
 | 
						|
}
 | 
						|
 | 
						|
// EncodeSessionState returns string representation of the current session
 | 
						|
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
 | 
						|
	var ss SessionState
 | 
						|
	if c == nil {
 | 
						|
		// Store only Email and User when cipher is unavailable
 | 
						|
		ss.Email = s.Email
 | 
						|
		ss.User = s.User
 | 
						|
	} else {
 | 
						|
		ss = *s
 | 
						|
		var err error
 | 
						|
		if ss.Email != "" {
 | 
						|
			ss.Email, err = c.Encrypt(ss.Email)
 | 
						|
			if err != nil {
 | 
						|
				return "", err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.User != "" {
 | 
						|
			ss.User, err = c.Encrypt(ss.User)
 | 
						|
			if err != nil {
 | 
						|
				return "", err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.AccessToken != "" {
 | 
						|
			ss.AccessToken, err = c.Encrypt(ss.AccessToken)
 | 
						|
			if err != nil {
 | 
						|
				return "", err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.IDToken != "" {
 | 
						|
			ss.IDToken, err = c.Encrypt(ss.IDToken)
 | 
						|
			if err != nil {
 | 
						|
				return "", err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.RefreshToken != "" {
 | 
						|
			ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
 | 
						|
			if err != nil {
 | 
						|
				return "", err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	// Embed SessionState and ExpiresOn pointer into SessionStateJSON
 | 
						|
	ssj := &SessionStateJSON{SessionState: &ss}
 | 
						|
	if !ss.CreatedAt.IsZero() {
 | 
						|
		ssj.CreatedAt = &ss.CreatedAt
 | 
						|
	}
 | 
						|
	if !ss.ExpiresOn.IsZero() {
 | 
						|
		ssj.ExpiresOn = &ss.ExpiresOn
 | 
						|
	}
 | 
						|
	b, err := json.Marshal(ssj)
 | 
						|
	return string(b), err
 | 
						|
}
 | 
						|
 | 
						|
// legacyDecodeSessionStatePlain decodes older plain session state string
 | 
						|
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
 | 
						|
	chunks := strings.Split(v, " ")
 | 
						|
	if len(chunks) != 2 {
 | 
						|
		return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks))
 | 
						|
	}
 | 
						|
 | 
						|
	user := strings.TrimPrefix(chunks[1], "user:")
 | 
						|
	email := strings.TrimPrefix(chunks[0], "email:")
 | 
						|
 | 
						|
	return &SessionState{User: user, Email: email}, nil
 | 
						|
}
 | 
						|
 | 
						|
// legacyDecodeSessionState attempts to decode the session state string
 | 
						|
// generated by v3.1.0 or older
 | 
						|
func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
 | 
						|
	chunks := strings.Split(v, "|")
 | 
						|
 | 
						|
	if c == nil {
 | 
						|
		if len(chunks) != 1 {
 | 
						|
			return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
 | 
						|
		}
 | 
						|
		return legacyDecodeSessionStatePlain(chunks[0])
 | 
						|
	}
 | 
						|
 | 
						|
	if len(chunks) != 4 && len(chunks) != 5 {
 | 
						|
		return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
 | 
						|
	}
 | 
						|
 | 
						|
	i := 0
 | 
						|
	ss, err := legacyDecodeSessionStatePlain(chunks[i])
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	i++
 | 
						|
	ss.AccessToken = chunks[i]
 | 
						|
 | 
						|
	if len(chunks) == 5 {
 | 
						|
		// SessionState with IDToken in v3.1.0
 | 
						|
		i++
 | 
						|
		ss.IDToken = chunks[i]
 | 
						|
	}
 | 
						|
 | 
						|
	i++
 | 
						|
	ts, err := strconv.Atoi(chunks[i])
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
 | 
						|
	}
 | 
						|
	ss.ExpiresOn = time.Unix(int64(ts), 0)
 | 
						|
 | 
						|
	i++
 | 
						|
	ss.RefreshToken = chunks[i]
 | 
						|
 | 
						|
	return ss, nil
 | 
						|
}
 | 
						|
 | 
						|
// DecodeSessionState decodes the session cookie string into a SessionState
 | 
						|
func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
 | 
						|
	var ssj SessionStateJSON
 | 
						|
	var ss *SessionState
 | 
						|
	err := json.Unmarshal([]byte(v), &ssj)
 | 
						|
	if err == nil && ssj.SessionState != nil {
 | 
						|
		// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
 | 
						|
		ss = ssj.SessionState
 | 
						|
		if ssj.CreatedAt != nil {
 | 
						|
			ss.CreatedAt = *ssj.CreatedAt
 | 
						|
		}
 | 
						|
		if ssj.ExpiresOn != nil {
 | 
						|
			ss.ExpiresOn = *ssj.ExpiresOn
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		// Try to decode a legacy string when json.Unmarshal failed
 | 
						|
		ss, err = legacyDecodeSessionState(v, c)
 | 
						|
		if err != nil {
 | 
						|
			return nil, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if c == nil {
 | 
						|
		// Load only Email and User when cipher is unavailable
 | 
						|
		ss = &SessionState{
 | 
						|
			Email: ss.Email,
 | 
						|
			User:  ss.User,
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		// Backward compatibility with using unecrypted Email
 | 
						|
		if ss.Email != "" {
 | 
						|
			decryptedEmail, errEmail := c.Decrypt(ss.Email)
 | 
						|
			if errEmail == nil {
 | 
						|
				ss.Email = decryptedEmail
 | 
						|
			}
 | 
						|
		}
 | 
						|
		// Backward compatibility with using unecrypted User
 | 
						|
		if ss.User != "" {
 | 
						|
			decryptedUser, errUser := c.Decrypt(ss.User)
 | 
						|
			if errUser == nil {
 | 
						|
				ss.User = decryptedUser
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.AccessToken != "" {
 | 
						|
			ss.AccessToken, err = c.Decrypt(ss.AccessToken)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.IDToken != "" {
 | 
						|
			ss.IDToken, err = c.Decrypt(ss.IDToken)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if ss.RefreshToken != "" {
 | 
						|
			ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
 | 
						|
			if err != nil {
 | 
						|
				return nil, err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if ss.User == "" {
 | 
						|
		ss.User = ss.Email
 | 
						|
	}
 | 
						|
	return ss, nil
 | 
						|
}
 |