153 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			153 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
| package sessions
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"time"
 | |
| 	"unicode/utf8"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
 | |
| )
 | |
| 
 | |
// SessionState is used to store information about the currently authenticated
// user session. Every field is tagged omitempty, so unset fields are left out
// of the JSON encoding.
type SessionState struct {
	AccessToken       string     `json:",omitempty"`
	IDToken           string     `json:",omitempty"`
	CreatedAt         *time.Time `json:",omitempty"`
	ExpiresOn         *time.Time `json:",omitempty"`
	RefreshToken      string     `json:",omitempty"`
	Email             string     `json:",omitempty"`
	User              string     `json:",omitempty"`
	PreferredUsername string     `json:",omitempty"`
}

// IsExpired reports whether the session has a non-zero ExpiresOn that lies in
// the past. A nil or zero ExpiresOn is never reported as expired.
func (s *SessionState) IsExpired() bool {
	return s.ExpiresOn != nil && !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now())
}

// Age returns the time elapsed since CreatedAt, using the current time
// truncated to whole seconds. It returns 0 when CreatedAt is nil or zero.
func (s *SessionState) Age() time.Duration {
	if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
		// NOTE(review): "now" is truncated to seconds, presumably to match the
		// precision CreatedAt is stored with — confirm.
		return time.Now().Truncate(time.Second).Sub(*s.CreatedAt)
	}
	return 0
}

// String returns a human-readable summary of the session. Token values are
// never printed; only their presence is reported. CreatedAt and ExpiresOn are
// included only when they are set and non-zero.
func (s *SessionState) String() string {
	o := fmt.Sprintf("Session{email:%s user:%s PreferredUsername:%s", s.Email, s.User, s.PreferredUsername)
	if s.AccessToken != "" {
		o += " token:true"
	}
	if s.IDToken != "" {
		o += " id_token:true"
	}
	// CreatedAt/ExpiresOn are pointers and time.Time.IsZero has a value
	// receiver: calling it through a nil pointer panics, so guard for nil.
	if s.CreatedAt != nil && !s.CreatedAt.IsZero() {
		o += fmt.Sprintf(" created:%s", *s.CreatedAt)
	}
	if s.ExpiresOn != nil && !s.ExpiresOn.IsZero() {
		o += fmt.Sprintf(" expires:%s", *s.ExpiresOn)
	}
	if s.RefreshToken != "" {
		o += " refresh_token:true"
	}
	return o + "}"
}
 | |
| 
 | |
| // EncodeSessionState returns string representation of the current session
 | |
| func (s *SessionState) EncodeSessionState(c encryption.Cipher) (string, error) {
 | |
| 	var ss SessionState
 | |
| 	if c == nil {
 | |
| 		// Store only Email and User when cipher is unavailable
 | |
| 		ss.Email = s.Email
 | |
| 		ss.User = s.User
 | |
| 		ss.PreferredUsername = s.PreferredUsername
 | |
| 	} else {
 | |
| 		ss = *s
 | |
| 		for _, s := range []*string{
 | |
| 			&ss.Email,
 | |
| 			&ss.User,
 | |
| 			&ss.PreferredUsername,
 | |
| 			&ss.AccessToken,
 | |
| 			&ss.IDToken,
 | |
| 			&ss.RefreshToken,
 | |
| 		} {
 | |
| 			err := into(s, c.Encrypt)
 | |
| 			if err != nil {
 | |
| 				return "", err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	b, err := json.Marshal(ss)
 | |
| 	return string(b), err
 | |
| }
 | |
| 
 | |
| // DecodeSessionState decodes the session cookie string into a SessionState
 | |
| func DecodeSessionState(v string, c encryption.Cipher) (*SessionState, error) {
 | |
| 	var ss SessionState
 | |
| 	err := json.Unmarshal([]byte(v), &ss)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("error unmarshalling session: %w", err)
 | |
| 	}
 | |
| 
 | |
| 	if c == nil {
 | |
| 		// Load only Email and User when cipher is unavailable
 | |
| 		ss = SessionState{
 | |
| 			Email:             ss.Email,
 | |
| 			User:              ss.User,
 | |
| 			PreferredUsername: ss.PreferredUsername,
 | |
| 		}
 | |
| 	} else {
 | |
| 		// Backward compatibility with using unencrypted Email or User
 | |
| 		// Decryption errors will leave original string
 | |
| 		err = into(&ss.Email, c.Decrypt)
 | |
| 		if err == nil {
 | |
| 			if !utf8.ValidString(ss.Email) {
 | |
| 				return nil, errors.New("invalid value for decrypted email")
 | |
| 			}
 | |
| 		}
 | |
| 		err = into(&ss.User, c.Decrypt)
 | |
| 		if err == nil {
 | |
| 			if !utf8.ValidString(ss.User) {
 | |
| 				return nil, errors.New("invalid value for decrypted user")
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		for _, s := range []*string{
 | |
| 			&ss.PreferredUsername,
 | |
| 			&ss.AccessToken,
 | |
| 			&ss.IDToken,
 | |
| 			&ss.RefreshToken,
 | |
| 		} {
 | |
| 			err := into(s, c.Decrypt)
 | |
| 			if err != nil {
 | |
| 				return nil, err
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return &ss, nil
 | |
| }
 | |
| 
 | |
// codecFunc transforms a byte slice and returns the result; in this file it is
// used with a cipher's Encrypt and Decrypt methods.
type codecFunc func([]byte) ([]byte, error)

// into runs f over the string pointed to by s and stores the output back into
// *s. A nil pointer or an empty string is skipped without calling f. If f
// fails, *s keeps its previous value and f's error is returned.
func into(s *string, f codecFunc) error {
	if s != nil && *s != "" {
		transformed, err := f([]byte(*s))
		if err != nil {
			return err
		}
		*s = string(transformed)
	}
	return nil
}
 |