258 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			258 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
| package cookies
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net/http"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
 | |
| 	"github.com/vmihailenco/msgpack/v5"
 | |
| )
 | |
| 
 | |
| // CSRF manages various nonces stored in the CSRF cookie during the initial
 | |
| // authentication flows.
 | |
| type CSRF interface {
 | |
| 	HashOAuthState() string
 | |
| 	HashOIDCNonce() string
 | |
| 	CheckOAuthState(string) bool
 | |
| 	CheckOIDCNonce(string) bool
 | |
| 	GetCodeVerifier() string
 | |
| 
 | |
| 	SetSessionNonce(s *sessions.SessionState)
 | |
| 
 | |
| 	SetCookie(http.ResponseWriter, *http.Request) (*http.Cookie, error)
 | |
| 	ClearCookie(http.ResponseWriter, *http.Request)
 | |
| }
 | |
| 
 | |
| type csrf struct {
 | |
| 	// OAuthState holds the OAuth2 state parameter's nonce component set in the
 | |
| 	// initial authentication request and mirrored back in the callback
 | |
| 	// redirect from the IdP for CSRF protection.
 | |
| 	OAuthState []byte `msgpack:"s,omitempty"`
 | |
| 
 | |
| 	// OIDCNonce holds the OIDC nonce parameter used in the initial authentication
 | |
| 	// and then set in all subsequent OIDC ID Tokens as the nonce claim. This
 | |
| 	// is used to mitigate replay attacks.
 | |
| 	OIDCNonce []byte `msgpack:"n,omitempty"`
 | |
| 
 | |
| 	// CodeVerifier holds the unobfuscated PKCE code verification string
 | |
| 	// which is used to compare the code challenge when exchanging the
 | |
| 	// authentication code.
 | |
| 	CodeVerifier string `msgpack:"cv,omitempty"`
 | |
| 
 | |
| 	cookieOpts *options.Cookie
 | |
| 	time       clock.Clock
 | |
| }
 | |
| 
 | |
| // csrfStateLength sizes the state substring used in the CSRF cookie name; the code slices up to csrfStateLength-1
 | |
| const csrfStateLength int = 9
 | |
| 
 | |
| // NewCSRF creates a CSRF with random nonces
 | |
| func NewCSRF(opts *options.Cookie, codeVerifier string) (CSRF, error) {
 | |
| 	state, err := encryption.Nonce(32)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	nonce, err := encryption.Nonce(32)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &csrf{
 | |
| 		OAuthState:   state,
 | |
| 		OIDCNonce:    nonce,
 | |
| 		CodeVerifier: codeVerifier,
 | |
| 
 | |
| 		cookieOpts: opts,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // LoadCSRFCookie loads a CSRF object from a request's CSRF cookie
 | |
| func LoadCSRFCookie(req *http.Request, opts *options.Cookie) (CSRF, error) {
 | |
| 	cookieName := GenerateCookieName(req, opts)
 | |
| 
 | |
| 	cookies := req.Cookies()
 | |
| 	for _, cookie := range cookies {
 | |
| 		if cookie.Name != cookieName {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		csrf, err := decodeCSRFCookie(cookie, opts)
 | |
| 		if err != nil {
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		return csrf, nil
 | |
| 	}
 | |
| 
 | |
| 	return nil, errors.New("CSRF cookie not found")
 | |
| }
 | |
| 
 | |
| // GenerateCookieName in case cookie options state that CSRF cookie has fixed name then set fixed name, otherwise
 | |
| // build name based on the state
 | |
| func GenerateCookieName(req *http.Request, opts *options.Cookie) string {
 | |
| 	stateSubstring := ""
 | |
| 	if opts.CSRFPerRequest {
 | |
| 		// csrfCookieName will include a substring of the state to enable multiple csrf cookies
 | |
| 		// in case of parallel requests
 | |
| 		stateSubstring = ExtractStateSubstring(req)
 | |
| 	}
 | |
| 	return csrfCookieName(opts, stateSubstring)
 | |
| }
 | |
| 
 | |
| func (c *csrf) GetCodeVerifier() string {
 | |
| 	return c.CodeVerifier
 | |
| }
 | |
| 
 | |
| // HashOAuthState returns the hash of the OAuth state nonce
 | |
| func (c *csrf) HashOAuthState() string {
 | |
| 	return encryption.HashNonce(c.OAuthState)
 | |
| }
 | |
| 
 | |
| // HashOIDCNonce returns the hash of the OIDC nonce
 | |
| func (c *csrf) HashOIDCNonce() string {
 | |
| 	return encryption.HashNonce(c.OIDCNonce)
 | |
| }
 | |
| 
 | |
| // CheckOAuthState compares the OAuth state nonce against a potential
 | |
| // hash of it
 | |
| func (c *csrf) CheckOAuthState(hashed string) bool {
 | |
| 	return encryption.CheckNonce(c.OAuthState, hashed)
 | |
| }
 | |
| 
 | |
| // CheckOIDCNonce compares the OIDC nonce against a potential hash of it
 | |
| func (c *csrf) CheckOIDCNonce(hashed string) bool {
 | |
| 	return encryption.CheckNonce(c.OIDCNonce, hashed)
 | |
| }
 | |
| 
 | |
| // SetSessionNonce sets the OIDCNonce on a SessionState
 | |
| func (c *csrf) SetSessionNonce(s *sessions.SessionState) {
 | |
| 	s.Nonce = c.OIDCNonce
 | |
| }
 | |
| 
 | |
| // SetCookie encodes the CSRF to a signed cookie and sets it on the ResponseWriter
 | |
| func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error) {
 | |
| 	encoded, err := c.encodeCookie()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	cookie := MakeCookieFromOptions(
 | |
| 		req,
 | |
| 		c.cookieName(),
 | |
| 		encoded,
 | |
| 		c.cookieOpts,
 | |
| 		c.cookieOpts.CSRFExpire,
 | |
| 		c.time.Now(),
 | |
| 	)
 | |
| 	http.SetCookie(rw, cookie)
 | |
| 
 | |
| 	return cookie, nil
 | |
| }
 | |
| 
 | |
| // ClearCookie removes the CSRF cookie
 | |
| func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) {
 | |
| 	http.SetCookie(rw, MakeCookieFromOptions(
 | |
| 		req,
 | |
| 		c.cookieName(),
 | |
| 		"",
 | |
| 		c.cookieOpts,
 | |
| 		time.Hour*-1,
 | |
| 		c.time.Now(),
 | |
| 	))
 | |
| }
 | |
| 
 | |
| // encodeCookie MessagePack encodes and encrypts the CSRF and then creates a
 | |
| // signed cookie value
 | |
| func (c *csrf) encodeCookie() (string, error) {
 | |
| 	packed, err := msgpack.Marshal(c)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("error marshalling CSRF to msgpack: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	encrypted, err := encrypt(packed, c.cookieOpts)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return encryption.SignedValue(c.cookieOpts.Secret, c.cookieName(), encrypted, c.time.Now())
 | |
| }
 | |
| 
 | |
| // decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
 | |
| // cookie into a CSRF struct
 | |
| func decodeCSRFCookie(cookie *http.Cookie, opts *options.Cookie) (*csrf, error) {
 | |
| 	val, _, ok := encryption.Validate(cookie, opts.Secret, opts.Expire)
 | |
| 	if !ok {
 | |
| 		return nil, errors.New("CSRF cookie failed validation")
 | |
| 	}
 | |
| 
 | |
| 	decrypted, err := decrypt(val, opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// Valid cookie, Unmarshal the CSRF
 | |
| 	csrf := &csrf{cookieOpts: opts}
 | |
| 	err = msgpack.Unmarshal(decrypted, csrf)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return csrf, nil
 | |
| }
 | |
| 
 | |
| // cookieName returns the CSRF cookie's name
 | |
| func (c *csrf) cookieName() string {
 | |
| 	stateSubstring := ""
 | |
| 	if c.cookieOpts.CSRFPerRequest {
 | |
| 		stateSubstring = encryption.HashNonce(c.OAuthState)[0 : csrfStateLength-1]
 | |
| 	}
 | |
| 	return csrfCookieName(c.cookieOpts, stateSubstring)
 | |
| }
 | |
| 
 | |
| func csrfCookieName(opts *options.Cookie, stateSubstring string) string {
 | |
| 	if stateSubstring == "" {
 | |
| 		return fmt.Sprintf("%v_csrf", opts.Name)
 | |
| 	}
 | |
| 	return fmt.Sprintf("%v_csrf_%v", opts.Name, stateSubstring)
 | |
| }
 | |
| 
 | |
| // ExtractStateSubstring extract the initial state characters, to add it to the CSRF cookie name
 | |
| func ExtractStateSubstring(req *http.Request) string {
 | |
| 	lastChar := csrfStateLength - 1
 | |
| 	stateSubstring := ""
 | |
| 
 | |
| 	state := req.URL.Query()["state"]
 | |
| 	if state[0] != "" {
 | |
| 		state := state[0]
 | |
| 		if lastChar <= len(state) {
 | |
| 			stateSubstring = state[0:lastChar]
 | |
| 		}
 | |
| 	}
 | |
| 	return stateSubstring
 | |
| }
 | |
| 
 | |
| func encrypt(data []byte, opts *options.Cookie) ([]byte, error) {
 | |
| 	cipher, err := makeCipher(opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return cipher.Encrypt(data)
 | |
| }
 | |
| 
 | |
| func decrypt(data []byte, opts *options.Cookie) ([]byte, error) {
 | |
| 	cipher, err := makeCipher(opts)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return cipher.Decrypt(data)
 | |
| }
 | |
| 
 | |
| func makeCipher(opts *options.Cookie) (encryption.Cipher, error) {
 | |
| 	return encryption.NewCFBCipher(encryption.SecretBytes(opts.Secret))
 | |
| }
 |