162 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
	
| package persistence
 | |
| 
 | |
| import (
 | |
| 	"encoding/base64"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | |
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | |
| 	. "github.com/onsi/ginkgo"
 | |
| 	. "github.com/onsi/ginkgo/extensions/table"
 | |
| 	. "github.com/onsi/gomega"
 | |
| )
 | |
| 
 | |
| var _ = Describe("Session Ticket Tests", func() {
 | |
| 	Context("encodeTicket & decodeTicket", func() {
 | |
| 		type ticketTableInput struct {
 | |
| 			ticket        *ticket
 | |
| 			encodedTicket string
 | |
| 			expectedError error
 | |
| 		}
 | |
| 
 | |
| 		DescribeTable("encodeTicket should decodeTicket back when valid",
 | |
| 			func(in ticketTableInput) {
 | |
| 				if in.ticket != nil {
 | |
| 					enc := in.ticket.encodeTicket()
 | |
| 					Expect(enc).To(Equal(in.encodedTicket))
 | |
| 
 | |
| 					dec, err := decodeTicket(enc, in.ticket.options)
 | |
| 					Expect(err).ToNot(HaveOccurred())
 | |
| 					Expect(dec).To(Equal(in.ticket))
 | |
| 				} else {
 | |
| 					_, err := decodeTicket(in.encodedTicket, nil)
 | |
| 					Expect(err).To(MatchError(in.expectedError))
 | |
| 				}
 | |
| 			},
 | |
| 			Entry("with a valid ticket", ticketTableInput{
 | |
| 				ticket: &ticket{
 | |
| 					id:     "dummy-0123456789abcdef",
 | |
| 					secret: []byte("0123456789abcdef"),
 | |
| 					options: &options.Cookie{
 | |
| 						Name: "dummy",
 | |
| 					},
 | |
| 				},
 | |
| 				encodedTicket: fmt.Sprintf("%s.%s",
 | |
| 					"dummy-0123456789abcdef",
 | |
| 					base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))),
 | |
| 				expectedError: nil,
 | |
| 			}),
 | |
| 			Entry("with an invalid encoded ticket with 1 part", ticketTableInput{
 | |
| 				ticket:        nil,
 | |
| 				encodedTicket: "dummy-0123456789abcdef",
 | |
| 				expectedError: errors.New("failed to decode ticket"),
 | |
| 			}),
 | |
| 			Entry("with an invalid base64 encoded secret", ticketTableInput{
 | |
| 				ticket:        nil,
 | |
| 				encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$",
 | |
| 				expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"),
 | |
| 			}),
 | |
| 		)
 | |
| 	})
 | |
| 
 | |
| 	Context("saveSession", func() {
 | |
| 		It("uses the passed save function", func() {
 | |
| 			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			c, err := t.makeCipher()
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			ss := &sessions.SessionState{User: "foobar"}
 | |
| 			store := map[string][]byte{}
 | |
| 			err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error {
 | |
| 				store[k] = v
 | |
| 				return nil
 | |
| 			})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			stored, err := sessions.DecodeSessionState(store[t.id], c, false)
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 			Expect(stored).To(Equal(ss))
 | |
| 		})
 | |
| 
 | |
| 		It("errors when the saveFunc errors", func() {
 | |
| 			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			err = t.saveSession(
 | |
| 				&sessions.SessionState{User: "foobar"},
 | |
| 				func(k string, v []byte, e time.Duration) error {
 | |
| 					return errors.New("save error")
 | |
| 				})
 | |
| 			Expect(err).To(MatchError(errors.New("save error")))
 | |
| 		})
 | |
| 	})
 | |
| 
 | |
| 	Context("loadSession", func() {
 | |
| 		It("uses the passed load function", func() {
 | |
| 			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			c, err := t.makeCipher()
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			ss := &sessions.SessionState{
 | |
| 				User: "foobar",
 | |
| 				Lock: &sessions.NoOpLock{},
 | |
| 			}
 | |
| 			loadedSession, err := t.loadSession(
 | |
| 				func(k string) ([]byte, error) {
 | |
| 					return ss.EncodeSessionState(c, false)
 | |
| 				},
 | |
| 				func(k string) sessions.Lock {
 | |
| 					return &sessions.NoOpLock{}
 | |
| 				})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 			Expect(loadedSession).To(Equal(ss))
 | |
| 		})
 | |
| 
 | |
| 		It("errors when the loadFunc errors", func() {
 | |
| 			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			data, err := t.loadSession(
 | |
| 				func(k string) ([]byte, error) {
 | |
| 					return nil, errors.New("load error")
 | |
| 				},
 | |
| 				func(k string) sessions.Lock {
 | |
| 					return &sessions.NoOpLock{}
 | |
| 				})
 | |
| 			Expect(data).To(BeNil())
 | |
| 			Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error")))
 | |
| 		})
 | |
| 	})
 | |
| 
 | |
| 	Context("clearSession", func() {
 | |
| 		It("uses the passed clear function", func() {
 | |
| 			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			var tracker string
 | |
| 			err = t.clearSession(func(k string) error {
 | |
| 				tracker = k
 | |
| 				return nil
 | |
| 			})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 			Expect(tracker).To(Equal(t.id))
 | |
| 		})
 | |
| 
 | |
| 		It("errors when the clearFunc errors", func() {
 | |
| 			t, err := newTicket(&options.Cookie{Name: "dummy"})
 | |
| 			Expect(err).ToNot(HaveOccurred())
 | |
| 
 | |
| 			err = t.clearSession(func(k string) error {
 | |
| 				return errors.New("clear error")
 | |
| 			})
 | |
| 			Expect(err).To(MatchError(errors.New("clear error")))
 | |
| 		})
 | |
| 	})
 | |
| })
 |