217 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			217 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
package persistence
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/aes"
 | 
						|
	"crypto/cipher"
 | 
						|
	"crypto/rand"
 | 
						|
	"encoding/base64"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/options"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
 | 
						|
	. "github.com/onsi/ginkgo"
 | 
						|
	. "github.com/onsi/ginkgo/extensions/table"
 | 
						|
	. "github.com/onsi/gomega"
 | 
						|
)
 | 
						|
 | 
						|
var _ = Describe("Session Ticket Tests", func() {
 | 
						|
	Context("encodeTicket & decodeTicket", func() {
 | 
						|
		type ticketTableInput struct {
 | 
						|
			ticket        *ticket
 | 
						|
			encodedTicket string
 | 
						|
			expectedError error
 | 
						|
		}
 | 
						|
 | 
						|
		DescribeTable("encodeTicket should decodeTicket back when valid",
 | 
						|
			func(in ticketTableInput) {
 | 
						|
				if in.ticket != nil {
 | 
						|
					enc := in.ticket.encodeTicket()
 | 
						|
					Expect(enc).To(Equal(in.encodedTicket))
 | 
						|
 | 
						|
					dec, err := decodeTicket(enc, in.ticket.options)
 | 
						|
					Expect(err).ToNot(HaveOccurred())
 | 
						|
					Expect(dec).To(Equal(in.ticket))
 | 
						|
				} else {
 | 
						|
					_, err := decodeTicket(in.encodedTicket, nil)
 | 
						|
					Expect(err).To(MatchError(in.expectedError))
 | 
						|
				}
 | 
						|
			},
 | 
						|
			Entry("with a valid ticket", ticketTableInput{
 | 
						|
				ticket: &ticket{
 | 
						|
					id:     "dummy-0123456789abcdef",
 | 
						|
					secret: []byte("0123456789abcdef"),
 | 
						|
					options: &options.Cookie{
 | 
						|
						Name: "dummy",
 | 
						|
					},
 | 
						|
				},
 | 
						|
				encodedTicket: fmt.Sprintf("%s.%s",
 | 
						|
					"dummy-0123456789abcdef",
 | 
						|
					base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))),
 | 
						|
				expectedError: nil,
 | 
						|
			}),
 | 
						|
			Entry("with an invalid encoded ticket with 1 part", ticketTableInput{
 | 
						|
				ticket:        nil,
 | 
						|
				encodedTicket: "dummy-0123456789abcdef",
 | 
						|
				expectedError: errors.New("failed to decode ticket"),
 | 
						|
			}),
 | 
						|
			Entry("with an invalid base64 encoded secret", ticketTableInput{
 | 
						|
				ticket:        nil,
 | 
						|
				encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$",
 | 
						|
				expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"),
 | 
						|
			}),
 | 
						|
		)
 | 
						|
	})
 | 
						|
 | 
						|
	Context("saveSession", func() {
 | 
						|
		It("uses the passed save function", func() {
 | 
						|
			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			c, err := t.makeCipher()
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			ss := &sessions.SessionState{User: "foobar"}
 | 
						|
			store := map[string][]byte{}
 | 
						|
			err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error {
 | 
						|
				store[k] = v
 | 
						|
				return nil
 | 
						|
			})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			stored, err := sessions.DecodeSessionState(store[t.id], c, false)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
			Expect(stored).To(Equal(ss))
 | 
						|
		})
 | 
						|
 | 
						|
		It("errors when the saveFunc errors", func() {
 | 
						|
			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			err = t.saveSession(
 | 
						|
				&sessions.SessionState{User: "foobar"},
 | 
						|
				func(k string, v []byte, e time.Duration) error {
 | 
						|
					return errors.New("save error")
 | 
						|
				})
 | 
						|
			Expect(err).To(MatchError(errors.New("save error")))
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("loadSession", func() {
 | 
						|
		It("uses the passed load function", func() {
 | 
						|
			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			c, err := t.makeCipher()
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			ss := &sessions.SessionState{User: "foobar"}
 | 
						|
			loadedSession, err := t.loadSession(func(k string) ([]byte, error) {
 | 
						|
				return ss.EncodeSessionState(c, false)
 | 
						|
			})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
			Expect(loadedSession).To(Equal(ss))
 | 
						|
		})
 | 
						|
 | 
						|
		It("errors when the loadFunc errors", func() {
 | 
						|
			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			data, err := t.loadSession(func(k string) ([]byte, error) {
 | 
						|
				return nil, errors.New("load error")
 | 
						|
			})
 | 
						|
			Expect(data).To(BeNil())
 | 
						|
			Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error")))
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("clearSession", func() {
 | 
						|
		It("uses the passed clear function", func() {
 | 
						|
			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			var tracker string
 | 
						|
			err = t.clearSession(func(k string) error {
 | 
						|
				tracker = k
 | 
						|
				return nil
 | 
						|
			})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
			Expect(tracker).To(Equal(t.id))
 | 
						|
		})
 | 
						|
 | 
						|
		It("errors when the clearFunc errors", func() {
 | 
						|
			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			err = t.clearSession(func(k string) error {
 | 
						|
				return errors.New("clear error")
 | 
						|
			})
 | 
						|
			Expect(err).To(MatchError(errors.New("clear error")))
 | 
						|
		})
 | 
						|
	})
 | 
						|
})
 | 
						|
 | 
						|
// TestLegacyV5DecodeSession tests the fallback to LegacyV5DecodeSession
 | 
						|
// when a V5 encoded session is in Redis
 | 
						|
//
 | 
						|
// TODO (@NickMeves): Remove when this is deprecated (likely V7)
 | 
						|
func Test_legacyV5LoadSession(t *testing.T) {
 | 
						|
	testCases, _, _ := sessions.CreateLegacyV5TestCases(t)
 | 
						|
 | 
						|
	for testName, tc := range testCases {
 | 
						|
		t.Run(testName, func(t *testing.T) {
 | 
						|
			g := NewWithT(t)
 | 
						|
 | 
						|
			secret := make([]byte, aes.BlockSize)
 | 
						|
			_, err := io.ReadFull(rand.Reader, secret)
 | 
						|
			g.Expect(err).ToNot(HaveOccurred())
 | 
						|
			tckt := &ticket{
 | 
						|
				secret: secret,
 | 
						|
				options: &options.Cookie{
 | 
						|
					Secret: base64.RawURLEncoding.EncodeToString([]byte(sessions.LegacyV5TestSecret)),
 | 
						|
				},
 | 
						|
			}
 | 
						|
 | 
						|
			encrypted, err := legacyStoreValue(tc.Input, tckt.secret)
 | 
						|
			g.Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			ss, err := tckt.legacyV5LoadSession(encrypted)
 | 
						|
			if tc.Error {
 | 
						|
				g.Expect(err).To(HaveOccurred())
 | 
						|
				g.Expect(ss).To(BeNil())
 | 
						|
				return
 | 
						|
			}
 | 
						|
			g.Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			// Compare sessions without *time.Time fields
 | 
						|
			exp := *tc.Output
 | 
						|
			exp.CreatedAt = nil
 | 
						|
			exp.ExpiresOn = nil
 | 
						|
			act := *ss
 | 
						|
			act.CreatedAt = nil
 | 
						|
			act.ExpiresOn = nil
 | 
						|
			g.Expect(exp).To(Equal(act))
 | 
						|
		})
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// legacyStoreValue implements the legacy V5 Redis persistence AES-CFB value encryption
//
// value is the plaintext to encrypt. ticketSecret is used both as the AES key
// and as the CFB initialization vector, so it must be exactly one AES block
// (aes.BlockSize bytes) long. The returned ciphertext has the same length as
// value. An error is returned when ticketSecret is not a valid AES key, or is
// a valid key whose length cannot serve as the IV.
//
// TODO (@NickMeves): Remove when this is deprecated (likely V7)
func legacyStoreValue(value string, ticketSecret []byte) ([]byte, error) {
	block, err := aes.NewCipher(ticketSecret)
	if err != nil {
		return nil, fmt.Errorf("error initiating cipher block: %w", err)
	}

	// Use secret as the Initialization Vector too, because each entry has its own key.
	// cipher.NewCFBEncrypter panics when the IV length differs from the block
	// size, which is the case for otherwise valid 24 and 32 byte AES keys, so
	// reject those up front with an error.
	if len(ticketSecret) != block.BlockSize() {
		return nil, fmt.Errorf(
			"error initiating cipher stream: secret length %d does not match IV length %d",
			len(ticketSecret), block.BlockSize())
	}

	ciphertext := make([]byte, len(value))
	stream := cipher.NewCFBEncrypter(block, ticketSecret)
	stream.XORKeyStream(ciphertext, []byte(value))

	return ciphertext, nil
}
 |