425 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			425 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
package tests
 | 
						|
 | 
						|
import (
 | 
						|
	"crypto/rand"
 | 
						|
	"net/http"
 | 
						|
	"net/http/httptest"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | 
						|
	sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
 | 
						|
	cookiesapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
 | 
						|
	. "github.com/onsi/ginkgo"
 | 
						|
	. "github.com/onsi/gomega"
 | 
						|
)
 | 
						|
 | 
						|
// testInput is passed to test function as a pointer.
 | 
						|
// This allows BeforeEach blocks to initialise and use these values after
 | 
						|
// Ginkgo has unpacked the tests.
 | 
						|
// Interfaces have to be wrapped in closures otherwise nil pointers are thrown.
 | 
						|
type testInput struct {
 | 
						|
	cookieOpts            *options.Cookie
 | 
						|
	ss                    sessionStoreFunc
 | 
						|
	session               *sessionsapi.SessionState
 | 
						|
	request               *http.Request
 | 
						|
	response              *httptest.ResponseRecorder
 | 
						|
	persistentFastForward PersistentStoreFastForwardFunc
 | 
						|
}
 | 
						|
 | 
						|
// sessionStoreFunc is used in testInput to wrap the SessionStore interface.
 | 
						|
type sessionStoreFunc func() sessionsapi.SessionStore
 | 
						|
 | 
						|
// PersistentStoreFastForwardFunc is used to adjust the time of the persistent
 | 
						|
// store to fast forward expiry of sessions.
 | 
						|
type PersistentStoreFastForwardFunc func(time.Duration) error
 | 
						|
 | 
						|
// NewSessionStoreFunc allows any session store implementation to configure their
 | 
						|
// own session store before each test.
 | 
						|
type NewSessionStoreFunc func(sessionOpts *options.SessionOptions, cookieOpts *options.Cookie) (sessionsapi.SessionStore, error)
 | 
						|
 | 
						|
func RunSessionStoreTests(newSS NewSessionStoreFunc, persistentFastForward PersistentStoreFastForwardFunc) {
 | 
						|
	Describe("Session Store Suite", func() {
 | 
						|
		var opts *options.SessionOptions
 | 
						|
		var ss sessionsapi.SessionStore
 | 
						|
		var input testInput
 | 
						|
		var cookieSecret []byte
 | 
						|
 | 
						|
		getSessionStore := func() sessionsapi.SessionStore {
 | 
						|
			return ss
 | 
						|
		}
 | 
						|
 | 
						|
		BeforeEach(func() {
 | 
						|
			ss = nil
 | 
						|
			opts = &options.SessionOptions{}
 | 
						|
 | 
						|
			// A secret is required to create a Cipher, validation ensures it is the correct
 | 
						|
			// length before a session store is initialised.
 | 
						|
			cookieSecret = make([]byte, 32)
 | 
						|
			_, err := rand.Read(cookieSecret)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			// Set default options in CookieOptions
 | 
						|
			cookieOpts := &options.Cookie{
 | 
						|
				Name:     "_oauth2_proxy",
 | 
						|
				Path:     "/",
 | 
						|
				Expire:   time.Duration(168) * time.Hour,
 | 
						|
				Refresh:  time.Duration(1) * time.Hour,
 | 
						|
				Secure:   true,
 | 
						|
				HTTPOnly: true,
 | 
						|
				SameSite: "",
 | 
						|
				Secret:   string(cookieSecret),
 | 
						|
			}
 | 
						|
 | 
						|
			expires := time.Now().Add(1 * time.Hour)
 | 
						|
			session := &sessionsapi.SessionState{
 | 
						|
				AccessToken:  "AccessToken",
 | 
						|
				IDToken:      "IDToken",
 | 
						|
				ExpiresOn:    &expires,
 | 
						|
				RefreshToken: "RefreshToken",
 | 
						|
				Email:        "john.doe@example.com",
 | 
						|
				User:         "john.doe",
 | 
						|
			}
 | 
						|
 | 
						|
			request := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
			response := httptest.NewRecorder()
 | 
						|
 | 
						|
			input = testInput{
 | 
						|
				cookieOpts:            cookieOpts,
 | 
						|
				ss:                    getSessionStore,
 | 
						|
				session:               session,
 | 
						|
				request:               request,
 | 
						|
				response:              response,
 | 
						|
				persistentFastForward: persistentFastForward,
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		Context("with default options", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				var err error
 | 
						|
				ss, err = newSS(opts, input.cookieOpts)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
			})
 | 
						|
 | 
						|
			SessionStoreInterfaceTests(&input)
 | 
						|
			if persistentFastForward != nil {
 | 
						|
				PersistentSessionStoreInterfaceTests(&input)
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		Context("with non-default options", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				input.cookieOpts = &options.Cookie{
 | 
						|
					Name:     "_cookie_name",
 | 
						|
					Path:     "/path",
 | 
						|
					Expire:   time.Duration(72) * time.Hour,
 | 
						|
					Refresh:  time.Duration(2) * time.Hour,
 | 
						|
					Secure:   false,
 | 
						|
					HTTPOnly: false,
 | 
						|
					Domains:  []string{"example.com"},
 | 
						|
					SameSite: "strict",
 | 
						|
					Secret:   string(cookieSecret),
 | 
						|
				}
 | 
						|
 | 
						|
				var err error
 | 
						|
				ss, err = newSS(opts, input.cookieOpts)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
			})
 | 
						|
 | 
						|
			SessionStoreInterfaceTests(&input)
 | 
						|
			if persistentFastForward != nil {
 | 
						|
				PersistentSessionStoreInterfaceTests(&input)
 | 
						|
			}
 | 
						|
		})
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func CheckCookieOptions(in *testInput) {
 | 
						|
	Context("the cookies returned", func() {
 | 
						|
		var cookies []*http.Cookie
 | 
						|
		BeforeEach(func() {
 | 
						|
			cookies = in.response.Result().Cookies()
 | 
						|
		})
 | 
						|
 | 
						|
		It("have the correct name set", func() {
 | 
						|
			if len(cookies) == 1 {
 | 
						|
				Expect(cookies[0].Name).To(Equal(in.cookieOpts.Name))
 | 
						|
			} else {
 | 
						|
				for _, cookie := range cookies {
 | 
						|
					Expect(cookie.Name).To(ContainSubstring(in.cookieOpts.Name))
 | 
						|
				}
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		It("have the correct path set", func() {
 | 
						|
			for _, cookie := range cookies {
 | 
						|
				Expect(cookie.Path).To(Equal(in.cookieOpts.Path))
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		It("have the correct domain set", func() {
 | 
						|
			for _, cookie := range cookies {
 | 
						|
				specifiedDomain := ""
 | 
						|
				if len(in.cookieOpts.Domains) > 0 {
 | 
						|
					specifiedDomain = in.cookieOpts.Domains[0]
 | 
						|
				}
 | 
						|
				Expect(cookie.Domain).To(Equal(specifiedDomain))
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		It("have the correct HTTPOnly set", func() {
 | 
						|
			for _, cookie := range cookies {
 | 
						|
				Expect(cookie.HttpOnly).To(Equal(in.cookieOpts.HTTPOnly))
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		It("have the correct secure set", func() {
 | 
						|
			for _, cookie := range cookies {
 | 
						|
				Expect(cookie.Secure).To(Equal(in.cookieOpts.Secure))
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		It("have the correct SameSite set", func() {
 | 
						|
			for _, cookie := range cookies {
 | 
						|
				Expect(cookie.SameSite).To(Equal(cookiesapi.ParseSameSite(in.cookieOpts.SameSite)))
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		It("have a signature timestamp matching session.CreatedAt", func() {
 | 
						|
			for _, cookie := range cookies {
 | 
						|
				if cookie.Value != "" {
 | 
						|
					parts := strings.Split(cookie.Value, "|")
 | 
						|
					Expect(parts).To(HaveLen(3))
 | 
						|
					Expect(parts[1]).To(Equal(strconv.Itoa(int(in.session.CreatedAt.Unix()))))
 | 
						|
				}
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func PersistentSessionStoreInterfaceTests(in *testInput) {
 | 
						|
	// Check that a stale cookie can't load an already cleared session
 | 
						|
	Context("when Clear is called on a persistent store", func() {
 | 
						|
		var resultCookies []*http.Cookie
 | 
						|
 | 
						|
		BeforeEach(func() {
 | 
						|
			req := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
			saveResp := httptest.NewRecorder()
 | 
						|
			err := in.ss().Save(saveResp, req, in.session)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			resultCookies = saveResp.Result().Cookies()
 | 
						|
			for _, c := range resultCookies {
 | 
						|
				in.request.AddCookie(c)
 | 
						|
			}
 | 
						|
			err = in.ss().Clear(in.response, in.request)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
		})
 | 
						|
 | 
						|
		Context("attempting to Load", func() {
 | 
						|
			var loadedAfterClear *sessionsapi.SessionState
 | 
						|
			var loadErr error
 | 
						|
 | 
						|
			BeforeEach(func() {
 | 
						|
				loadReq := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
				for _, c := range resultCookies {
 | 
						|
					loadReq.AddCookie(c)
 | 
						|
				}
 | 
						|
 | 
						|
				loadedAfterClear, loadErr = in.ss().Load(loadReq)
 | 
						|
			})
 | 
						|
 | 
						|
			It("returns an empty session", func() {
 | 
						|
				Expect(loadedAfterClear).To(BeNil())
 | 
						|
			})
 | 
						|
 | 
						|
			It("returns an error", func() {
 | 
						|
				Expect(loadErr).To(HaveOccurred())
 | 
						|
			})
 | 
						|
		})
 | 
						|
 | 
						|
		CheckCookieOptions(in)
 | 
						|
	})
 | 
						|
 | 
						|
	// Test TTLs and cleanup of persistent session storage
 | 
						|
	// For non-persistent we rely on the browser cookie lifecycle
 | 
						|
	Context("when Load is called on a persistent store", func() {
 | 
						|
		BeforeEach(func() {
 | 
						|
			req := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
			resp := httptest.NewRecorder()
 | 
						|
			err := in.ss().Save(resp, req, in.session)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			for _, cookie := range resp.Result().Cookies() {
 | 
						|
				in.request.AddCookie(cookie)
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		Context("after the refresh period, but before the cookie expire period", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				Expect(in.persistentFastForward(in.cookieOpts.Refresh + time.Minute)).To(Succeed())
 | 
						|
			})
 | 
						|
 | 
						|
			LoadSessionTests(in)
 | 
						|
		})
 | 
						|
 | 
						|
		Context("after the cookie expire period", func() {
 | 
						|
			var loadedSession *sessionsapi.SessionState
 | 
						|
			var err error
 | 
						|
 | 
						|
			BeforeEach(func() {
 | 
						|
				Expect(in.persistentFastForward(in.cookieOpts.Expire + time.Minute)).To(Succeed())
 | 
						|
 | 
						|
				loadedSession, err = in.ss().Load(in.request)
 | 
						|
				Expect(err).To(HaveOccurred())
 | 
						|
			})
 | 
						|
 | 
						|
			It("returns an error loading the session", func() {
 | 
						|
				Expect(err).To(HaveOccurred())
 | 
						|
			})
 | 
						|
 | 
						|
			It("returns an empty session", func() {
 | 
						|
				Expect(loadedSession).To(BeNil())
 | 
						|
			})
 | 
						|
		})
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func SessionStoreInterfaceTests(in *testInput) {
 | 
						|
	Context("when Save is called", func() {
 | 
						|
		Context("with no existing session", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				err := in.ss().Save(in.response, in.request, in.session)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
			})
 | 
						|
 | 
						|
			It("sets a `set-cookie` header in the response", func() {
 | 
						|
				Expect(in.response.Header().Get("set-cookie")).ToNot(BeEmpty())
 | 
						|
			})
 | 
						|
 | 
						|
			It("Ensures the session CreatedAt is not zero", func() {
 | 
						|
				Expect(in.session.CreatedAt.IsZero()).To(BeFalse())
 | 
						|
			})
 | 
						|
 | 
						|
			CheckCookieOptions(in)
 | 
						|
		})
 | 
						|
 | 
						|
		Context("with a broken session", func() {
 | 
						|
			BeforeEach(func() {
 | 
						|
				By("Using a valid cookie with a different providers session encoding")
 | 
						|
				broken := "BrokenSessionFromADifferentSessionImplementation"
 | 
						|
				value, err := encryption.SignedValue(in.cookieOpts.Secret, in.cookieOpts.Name, []byte(broken), time.Now())
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
				cookie := cookiesapi.MakeCookieFromOptions(in.request, in.cookieOpts.Name, value, in.cookieOpts, in.cookieOpts.Expire, time.Now())
 | 
						|
				in.request.AddCookie(cookie)
 | 
						|
 | 
						|
				err = in.ss().Save(in.response, in.request, in.session)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
			})
 | 
						|
 | 
						|
			It("sets a `set-cookie` header in the response", func() {
 | 
						|
				Expect(in.response.Header().Get("set-cookie")).ToNot(BeEmpty())
 | 
						|
			})
 | 
						|
 | 
						|
			It("Ensures the session CreatedAt is not zero", func() {
 | 
						|
				Expect(in.session.CreatedAt.IsZero()).To(BeFalse())
 | 
						|
			})
 | 
						|
 | 
						|
			CheckCookieOptions(in)
 | 
						|
		})
 | 
						|
 | 
						|
		Context("with an expired saved session", func() {
 | 
						|
			var err error
 | 
						|
			BeforeEach(func() {
 | 
						|
				By("saving a session")
 | 
						|
				req := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
				saveResp := httptest.NewRecorder()
 | 
						|
				err = in.ss().Save(saveResp, req, in.session)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
				By("and clearing the session")
 | 
						|
				for _, c := range saveResp.Result().Cookies() {
 | 
						|
					in.request.AddCookie(c)
 | 
						|
				}
 | 
						|
				clearResp := httptest.NewRecorder()
 | 
						|
				err = in.ss().Clear(clearResp, in.request)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
				By("then saving a request with the cleared session")
 | 
						|
				err = in.ss().Save(in.response, in.request, in.session)
 | 
						|
			})
 | 
						|
 | 
						|
			It("no error should occur", func() {
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
			})
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("when Clear is called", func() {
 | 
						|
		BeforeEach(func() {
 | 
						|
			req := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
			saveResp := httptest.NewRecorder()
 | 
						|
			err := in.ss().Save(saveResp, req, in.session)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			for _, c := range saveResp.Result().Cookies() {
 | 
						|
				in.request.AddCookie(c)
 | 
						|
			}
 | 
						|
			err = in.ss().Clear(in.response, in.request)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
		})
 | 
						|
 | 
						|
		It("sets a `set-cookie` header in the response", func() {
 | 
						|
			Expect(in.response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
 | 
						|
		})
 | 
						|
 | 
						|
		CheckCookieOptions(in)
 | 
						|
	})
 | 
						|
 | 
						|
	Context("when Load is called", func() {
 | 
						|
		BeforeEach(func() {
 | 
						|
			req := httptest.NewRequest("GET", "http://example.com/", nil)
 | 
						|
			resp := httptest.NewRecorder()
 | 
						|
			err := in.ss().Save(resp, req, in.session)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			for _, cookie := range resp.Result().Cookies() {
 | 
						|
				in.request.AddCookie(cookie)
 | 
						|
			}
 | 
						|
		})
 | 
						|
 | 
						|
		Context("before the refresh period", func() {
 | 
						|
			LoadSessionTests(in)
 | 
						|
		})
 | 
						|
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
func LoadSessionTests(in *testInput) {
 | 
						|
	var loadedSession *sessionsapi.SessionState
 | 
						|
	BeforeEach(func() {
 | 
						|
		var err error
 | 
						|
		loadedSession, err = in.ss().Load(in.request)
 | 
						|
		Expect(err).ToNot(HaveOccurred())
 | 
						|
	})
 | 
						|
 | 
						|
	It("loads a session equal to the original session", func() {
 | 
						|
		// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
 | 
						|
		l := *loadedSession
 | 
						|
		l.CreatedAt = nil
 | 
						|
		l.ExpiresOn = nil
 | 
						|
		s := *in.session
 | 
						|
		s.CreatedAt = nil
 | 
						|
		s.ExpiresOn = nil
 | 
						|
		Expect(l).To(Equal(s))
 | 
						|
 | 
						|
		// Compare time.Time separately
 | 
						|
		Expect(loadedSession.CreatedAt.Equal(*in.session.CreatedAt)).To(BeTrue())
 | 
						|
		Expect(loadedSession.ExpiresOn.Equal(*in.session.ExpiresOn)).To(BeTrue())
 | 
						|
 | 
						|
	})
 | 
						|
}
 |