182 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			182 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
| package cookies
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"net/http"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/clock"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
 | |
| 	"github.com/vmihailenco/msgpack/v4"
 | |
| )
 | |
| 
 | |
| // CSRF manages various nonces stored in the CSRF cookie during the initial
 | |
| // authentication flows.
 | |
| type CSRF interface {
 | |
| 	HashOAuthState() string
 | |
| 	HashOIDCNonce() string
 | |
| 	CheckOAuthState(string) bool
 | |
| 	CheckOIDCNonce(string) bool
 | |
| 
 | |
| 	SetSessionNonce(s *sessions.SessionState)
 | |
| 
 | |
| 	SetCookie(http.ResponseWriter, *http.Request) (*http.Cookie, error)
 | |
| 	ClearCookie(http.ResponseWriter, *http.Request)
 | |
| }
 | |
| 
 | |
| type csrf struct {
 | |
| 	// OAuthState holds the OAuth2 state parameter's nonce component set in the
 | |
| 	// initial authentication request and mirrored back in the callback
 | |
| 	// redirect from the IdP for CSRF protection.
 | |
| 	OAuthState []byte `msgpack:"s,omitempty"`
 | |
| 
 | |
| 	// OIDCNonce holds the OIDC nonce parameter used in the initial authentication
 | |
| 	// and then set in all subsequent OIDC ID Tokens as the nonce claim. This
 | |
| 	// is used to mitigate replay attacks.
 | |
| 	OIDCNonce []byte `msgpack:"n,omitempty"`
 | |
| 
 | |
| 	builder          Builder
 | |
| 	encryptionSecret string
 | |
| 	time             clock.Clock
 | |
| }
 | |
| 
 | |
| // NewCSRF creates a CSRF with random nonces
 | |
| func NewCSRF(builder Builder, secret string) (CSRF, error) {
 | |
| 	state, err := encryption.Nonce()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	nonce, err := encryption.Nonce()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return &csrf{
 | |
| 		OAuthState: state,
 | |
| 		OIDCNonce:  nonce,
 | |
| 
 | |
| 		builder:          builder,
 | |
| 		encryptionSecret: secret,
 | |
| 	}, nil
 | |
| }
 | |
| 
 | |
| // LoadCSRFCookie loads a CSRF object from a request's CSRF cookie
 | |
| func LoadCSRFCookie(req *http.Request, builder Builder, secret string) (CSRF, error) {
 | |
| 	cookieValue, err := builder.ValidateRequest(req)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("failed to validate CSRF cookie value: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return decodeCSRFCookie(cookieValue, builder, secret)
 | |
| }
 | |
| 
 | |
| // HashOAuthState returns the hash of the OAuth state nonce
 | |
| func (c *csrf) HashOAuthState() string {
 | |
| 	return encryption.HashNonce(c.OAuthState)
 | |
| }
 | |
| 
 | |
| // HashOIDCNonce returns the hash of the OIDC nonce
 | |
| func (c *csrf) HashOIDCNonce() string {
 | |
| 	return encryption.HashNonce(c.OIDCNonce)
 | |
| }
 | |
| 
 | |
| // CheckOAuthState compares the OAuth state nonce against a potential
 | |
| // hash of it
 | |
| func (c *csrf) CheckOAuthState(hashed string) bool {
 | |
| 	return encryption.CheckNonce(c.OAuthState, hashed)
 | |
| }
 | |
| 
 | |
| // CheckOIDCNonce compares the OIDC nonce against a potential hash of it
 | |
| func (c *csrf) CheckOIDCNonce(hashed string) bool {
 | |
| 	return encryption.CheckNonce(c.OIDCNonce, hashed)
 | |
| }
 | |
| 
 | |
| // SetSessionNonce sets the OIDCNonce on a SessionState
 | |
| func (c *csrf) SetSessionNonce(s *sessions.SessionState) {
 | |
| 	s.Nonce = c.OIDCNonce
 | |
| }
 | |
| 
 | |
| // SetCookie encodes the CSRF to a signed cookie and sets it on the ResponseWriter
 | |
| func (c *csrf) SetCookie(rw http.ResponseWriter, req *http.Request) (*http.Cookie, error) {
 | |
| 	encoded, err := c.encodeCookie()
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	cookie, err := c.builder.
 | |
| 		WithStart(c.time.Now()).
 | |
| 		MakeCookie(req, encoded)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	http.SetCookie(rw, cookie)
 | |
| 
 | |
| 	return cookie, nil
 | |
| }
 | |
| 
 | |
| // ClearCookie removes the CSRF cookie
 | |
| func (c *csrf) ClearCookie(rw http.ResponseWriter, req *http.Request) error {
 | |
| 	cookie, err := c.builder.
 | |
| 		WithExpiration(time.Hour*-1).
 | |
| 		WithStart(c.time.Now()).
 | |
| 		MakeCookie(req, "")
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("could not create cookie: %v", err)
 | |
| 	}
 | |
| 	http.SetCookie(rw, cookie)
 | |
| }
 | |
| 
 | |
| // encodeCookie MessagePack encodes and encrypts the CSRF and then creates a
 | |
| // signed cookie value
 | |
| func (c *csrf) encodeCookie() (string, error) {
 | |
| 	packed, err := msgpack.Marshal(c)
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("error marshalling CSRF to msgpack: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	encrypted, err := encrypt(packed, c.cookieOpts)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return string(encrypted), nil
 | |
| }
 | |
| 
 | |
| // decodeCSRFCookie validates the signature then decrypts and decodes a CSRF
 | |
| // cookie into a CSRF struct
 | |
| func decodeCSRFCookie(cookieValue string, builder Builder, secret string) (*csrf, error) {
 | |
| 	decrypted, err := decrypt([]byte(cookieValue), secret)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// Valid cookie, Unmarshal the CSRF
 | |
| 	csrf := &csrf{builder: builder, encryptionSecret: secret}
 | |
| 	if err := msgpack.Unmarshal(decrypted, csrf); err != nil {
 | |
| 		return nil, fmt.Errorf("error unmarshalling data to CSRF: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	return csrf, nil
 | |
| }
 | |
| 
 | |
| func encrypt(data []byte, secret string) ([]byte, error) {
 | |
| 	cipher, err := makeCipher(secret)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return cipher.Encrypt(data)
 | |
| }
 | |
| 
 | |
| func decrypt(data []byte, secret string) ([]byte, error) {
 | |
| 	cipher, err := makeCipher(secret)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return cipher.Decrypt(data)
 | |
| }
 | |
| 
 | |
| func makeCipher(secret string) (encryption.Cipher, error) {
 | |
| 	return encryption.NewCFBCipher(encryption.SecretBytes(secret))
 | |
| }
 |