311 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			311 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			Go
		
	
	
	
package cookies
 | 
						|
 | 
						|
import (
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"net/http"
 | 
						|
	"slices"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
 | 
						|
	"github.com/vmihailenco/msgpack/v5"
 | 
						|
)
 | 
						|
 | 
						|
// CSRF manages various nonces stored in the CSRF cookie during the initial
 | 
						|
// authentication flows.
 | 
						|
type CSRF interface {
 | 
						|
	HashOAuthState() string
 | 
						|
	HashOIDCNonce() string
 | 
						|
	CheckOAuthState(string) bool
 | 
						|
	CheckOIDCNonce(string) bool
 | 
						|
	GetCodeVerifier() string
 | 
						|
 | 
						|
	SetSessionNonce(s *sessions.SessionState)
 | 
						|
 | 
						|
	SetCookie(http.ResponseWriter, *http.Request) (*http.Cookie, error)
 | 
						|
	ClearCookie(http.ResponseWriter, *http.Request)
 | 
						|
}
 | 
						|
 | 
						|
type csrf struct {
 | 
						|
	// OAuthState holds the OAuth2 state parameter's nonce component set in the
 | 
						|
	// initial authentication request and mirrored back in the callback
 | 
						|
	// redirect from the IdP for CSRF protection.
 | 
						|
	OAuthState []byte `msgpack:"s,omitempty"`
 | 
						|
 | 
						|
	// OIDCNonce holds the OIDC nonce parameter used in the initial authentication
 | 
						|
	// and then set in all subsequent OIDC ID Tokens as the nonce claim. This
 | 
						|
	// is used to mitigate replay attacks.
 | 
						|
	OIDCNonce []byte `msgpack:"n,omitempty"`
 | 
						|
 | 
						|
	// CodeVerifier holds the unobfuscated PKCE code verification string
 | 
						|
	// which is used to compare the code challenge when exchanging the
 | 
						|
	// authentication code.
 | 
						|
	CodeVerifier string `msgpack:"cv,omitempty"`
 | 
						|
 | 
						|
	cookieOpts *options.Cookie
 | 
						|
	clock      func() time.Time
 | 
						|
}
 | 
						|
 | 
						|
// csrfStateLength bounds how much of the OAuth state is embedded in the name
// of a per-request CSRF cookie. Only the first csrfStateLength-1 (that is, 8)
// characters are used; see ExtractStateSubstring and (*csrf).cookieName.
//
// Changing this value renames every per-request CSRF cookie. Cookies already
// issued under the old length would then no longer be found.
const csrfStateLength int = 9
 | 
						|
 | 
						|
// NewCSRF creates a CSRF with random nonces
 | 
						|
func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) {
 | 
						|
	state, err := encryption.Nonce(32)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	nonce, err := encryption.Nonce(32)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return &csrf{
 | 
						|
		OAuthState:   state,
 | 
						|
		OIDCNonce:    nonce,
 | 
						|
		CodeVerifier: codeVerifier,
 | 
						|
 | 
						|
		cookieOpts: opts,
 | 
						|
		clock:      time.Now,
 | 
						|
	}, nil
 | 
						|
}
 | 
						|
 | 
						|
// LoadCSRFCookie loads a CSRF object from a request's CSRF cookie
 | 
						|
func LoadCSRFCookie(req *http.Request, cookieName string, opts *options.Cookie) (CSRF, error) {
 | 
						|
	cookies := req.Cookies()
 | 
						|
	for _, cookie := range cookies {
 | 
						|
		if cookie.Name != cookieName {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		csrf, err := decodeCSRFCookie(cookie, opts)
 | 
						|
		if err != nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		return csrf, nil
 | 
						|
	}
 | 
						|
 | 
						|
	return nil, fmt.Errorf("CSRF cookie with name '%v' was not found", cookieName)
 | 
						|
}
 | 
						|
 | 
						|
// GenerateCookieName in case cookie options state that CSRF cookie has fixed name then set fixed name, otherwise
 | 
						|
// build name based on the state
 | 
						|
func GenerateCookieName(opts *options.Cookie, state string) string {
 | 
						|
	stateSubstring := ""
 | 
						|
	if opts.CSRFPerRequest {
 | 
						|
		// csrfCookieName will include a substring of the state to enable multiple csrf cookies
 | 
						|
		// in case of parallel requests
 | 
						|
		stateSubstring = ExtractStateSubstring(state)
 | 
						|
	}
 | 
						|
	return csrfCookieName(opts, stateSubstring)
 | 
						|
}
 | 
						|
 | 
						|
func (c *csrf) GetCodeVerifier() string {
 | 
						|
	return c.CodeVerifier
 | 
						|
}
 | 
						|
 | 
						|
// HashOAuthState returns the hash of the OAuth state nonce
 | 
						|
func (c *csrf) HashOAuthState() string {
 | 
						|
	return encryption.HashNonce(c.OAuthState)
 | 
						|
}
 | 
						|
 | 
						|
// HashOIDCNonce returns the hash of the OIDC nonce
 | 
						|
func (c *csrf) HashOIDCNonce() string {
 | 
						|
	return encryption.HashNonce(c.OIDCNonce)
 | 
						|
}
 | 
						|
 | 
						|
// CheckOAuthState compares the OAuth state nonce against a potential
 | 
						|
// hash of it
 | 
						|
func (c *csrf) CheckOAuthState(hashed string) bool {
 | 
						|
	return encryption.CheckNonce(c.OAuthState, hashed)
 | 
						|
}
 | 
						|
 | 
						|
// CheckOIDCNonce compares the OIDC nonce against a potential hash of it
 | 
						|
func (c *csrf) CheckOIDCNonce(hashed string) bool {
 | 
						|
	return encryption.CheckNonce(c.OIDCNonce, hashed)
 | 
						|
}
 | 
						|
 | 
						|
// SetSessionNonce sets the OIDCNonce on a SessionState
 | 
						|
func (c *csrf) SetSessionNonce(s *sessions.SessionState) {
 | 
						|
	s.Nonce = c.OIDCNonce
 | 
						|
}
 | 
						|
 | 
						|
// SetCookie encodes the CSRF to a signed cookie and sets it on the ResponseWriter
 | 
						|
func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error) {
 | 
						|
	encoded, err := c.encodeCookie()
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	cookie := MakeCookieFromOptions(
 | 
						|
		req,
 | 
						|
		c.cookieName(),
 | 
						|
		encoded,
 | 
						|
		c.cookieOpts,
 | 
						|
		c.cookieOpts.CSRFExpire,
 | 
						|
	)
 | 
						|
	http.SetCookie(rw, cookie)
 | 
						|
 | 
						|
	return cookie, nil
 | 
						|
}
 | 
						|
 | 
						|
// ClearExtraCsrfCookies limits the amount of existing CSRF cookies by deleting
 | 
						|
// an excess of cookies controlled through the option CSRFPerRequestLimit
 | 
						|
func ClearExtraCsrfCookies(opts *options.Cookie, rw http.ResponseWriter, req *http.Request) {
 | 
						|
	if !opts.CSRFPerRequest || opts.CSRFPerRequestLimit <= 0 {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	cookies := req.Cookies()
 | 
						|
	existingCsrfCookies := []*http.Cookie{}
 | 
						|
	startsWith := fmt.Sprintf("%v_", opts.Name)
 | 
						|
 | 
						|
	// determine how many csrf cookies we have
 | 
						|
	for _, cookie := range cookies {
 | 
						|
		if strings.HasPrefix(cookie.Name, startsWith) && strings.HasSuffix(cookie.Name, "_csrf") {
 | 
						|
			existingCsrfCookies = append(existingCsrfCookies, cookie)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// short circuit return
 | 
						|
	if len(existingCsrfCookies) <= opts.CSRFPerRequestLimit {
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	decodedCookies := []*csrf{}
 | 
						|
	for _, cookie := range existingCsrfCookies {
 | 
						|
		decodedCookie, err := decodeCSRFCookie(cookie, opts)
 | 
						|
		if err != nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		decodedCookies = append(decodedCookies, decodedCookie)
 | 
						|
	}
 | 
						|
 | 
						|
	// delete the X oldest cookies
 | 
						|
	slices.SortStableFunc(decodedCookies, func(a, b *csrf) int {
 | 
						|
		return a.clock().Compare(b.clock())
 | 
						|
	})
 | 
						|
 | 
						|
	for i := 0; i < len(decodedCookies)-opts.CSRFPerRequestLimit; i++ {
 | 
						|
		decodedCookies[i].ClearCookie(rw, req)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// ClearCookie removes the CSRF cookie
 | 
						|
func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) {
 | 
						|
	http.SetCookie(rw, MakeCookieFromOptions(
 | 
						|
		req,
 | 
						|
		c.cookieName(),
 | 
						|
		"",
 | 
						|
		c.cookieOpts,
 | 
						|
		time.Hour*-1,
 | 
						|
	))
 | 
						|
}
 | 
						|
 | 
						|
// encodeCookie MessagePack encodes and encrypts the CSRF and then creates a
 | 
						|
// signed cookie value
 | 
						|
func (c *csrf) encodeCookie() (string, error) {
 | 
						|
	packed, err := msgpack.Marshal(c)
 | 
						|
	if err != nil {
 | 
						|
		return "", fmt.Errorf("error marshalling CSRF to msgpack: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	encrypted, err := encrypt(packed, c.cookieOpts)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	secret, err := c.cookieOpts.GetSecret()
 | 
						|
	if err != nil {
 | 
						|
		return "", fmt.Errorf("error getting cookie secret: %v", err)
 | 
						|
	}
 | 
						|
	return encryption.SignedValue(secret, c.cookieName(), encrypted, c.clock())
 | 
						|
}
 | 
						|
 | 
						|
// decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
 | 
						|
// cookie into a CSRF struct
 | 
						|
func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) {
 | 
						|
	secret, err := opts.GetSecret()
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("error getting cookie secret: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	val, t, ok := encryption.Validate(cookie, secret, opts.Expire)
 | 
						|
	if !ok {
 | 
						|
		return nil, errors.New("CSRF cookie failed validation")
 | 
						|
	}
 | 
						|
 | 
						|
	decrypted, err := decrypt(val, opts)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
 | 
						|
	return unmarshalCSRF(decrypted, opts, t)
 | 
						|
}
 | 
						|
 | 
						|
// unmarshalCSRF unmarshals decrypted data into a CSRF struct
 | 
						|
func unmarshalCSRF(decrypted []byte, opts *options.Cookie, csrfTime time.Time) (*csrf, error) {
 | 
						|
	csrf := &csrf{
 | 
						|
		cookieOpts: opts,
 | 
						|
		clock:      func() time.Time { return csrfTime },
 | 
						|
	}
 | 
						|
	if err := msgpack.Unmarshal(decrypted, csrf); err != nil {
 | 
						|
		return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
 | 
						|
	}
 | 
						|
	return csrf, nil
 | 
						|
}
 | 
						|
 | 
						|
// cookieName returns the CSRF cookie's name
 | 
						|
func (c *csrf) cookieName() string {
 | 
						|
	stateSubstring := ""
 | 
						|
	if c.cookieOpts.CSRFPerRequest {
 | 
						|
		stateSubstring = encryption.HashNonce(c.OAuthState)[0 : csrfStateLength-1]
 | 
						|
	}
 | 
						|
	return csrfCookieName(c.cookieOpts, stateSubstring)
 | 
						|
}
 | 
						|
 | 
						|
func csrfCookieName(opts *options.Cookie, stateSubstring string) string {
 | 
						|
	if stateSubstring == "" {
 | 
						|
		return fmt.Sprintf("%v_csrf", opts.Name)
 | 
						|
	}
 | 
						|
	return fmt.Sprintf("%v_%v_csrf", opts.Name, stateSubstring)
 | 
						|
}
 | 
						|
 | 
						|
// ExtractStateSubstring extract the initial state characters, to add it to the CSRF cookie name
 | 
						|
func ExtractStateSubstring(state string) string {
 | 
						|
	lastChar := csrfStateLength - 1
 | 
						|
	stateSubstring := ""
 | 
						|
	if lastChar <= len(state) {
 | 
						|
		stateSubstring = state[0:lastChar]
 | 
						|
	}
 | 
						|
	return stateSubstring
 | 
						|
}
 | 
						|
 | 
						|
func encrypt(data []byte, opts *options.Cookie) ([]byte, error) {
 | 
						|
	cipher, err := makeCipher(opts)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return cipher.Encrypt(data)
 | 
						|
}
 | 
						|
 | 
						|
func decrypt(data []byte, opts *options.Cookie) ([]byte, error) {
 | 
						|
	cipher, err := makeCipher(opts)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	return cipher.Decrypt(data)
 | 
						|
}
 | 
						|
 | 
						|
func makeCipher(opts *options.Cookie) (encryption.Cipher, error) {
 | 
						|
	secret, err := opts.GetSecret()
 | 
						|
	if err != nil {
 | 
						|
		return nil, fmt.Errorf("error getting cookie secret: %v", err)
 | 
						|
	}
 | 
						|
	return encryption.NewCFBCipher(encryption.SecretBytes(secret))
 | 
						|
}
 |