160 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			160 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
package redis_test
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"encoding/base64"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/alicebob/miniredis/v2"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
 | 
						|
	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/redis"
 | 
						|
	. "github.com/onsi/ginkgo"
 | 
						|
	. "github.com/onsi/gomega"
 | 
						|
)
 | 
						|
 | 
						|
var _ = Describe("Redis Client Tests", func() {
 | 
						|
	Context("with basic client", func() {
 | 
						|
		RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions {
 | 
						|
			return options.RedisStoreOptions{
 | 
						|
				ConnectionURL: "redis://" + mr.Addr(),
 | 
						|
			}
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("with cluster client", func() {
 | 
						|
		RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions {
 | 
						|
			return options.RedisStoreOptions{
 | 
						|
				ClusterConnectionURLs: []string{"redis://" + mr.Addr()},
 | 
						|
				UseCluster:            true,
 | 
						|
			}
 | 
						|
		})
 | 
						|
	})
 | 
						|
})
 | 
						|
 | 
						|
// getOptsFunc builds the Redis store options for a test run from the given
// miniredis instance.
type getOptsFunc func(mr *miniredis.Miniredis) options.RedisStoreOptions
 | 
						|
 | 
						|
func RunClientTests(getOptsFunc getOptsFunc) {
 | 
						|
	var mr *miniredis.Miniredis
 | 
						|
	var client redis.Client
 | 
						|
	var err error
 | 
						|
	var key string
 | 
						|
	var ctx context.Context
 | 
						|
 | 
						|
	BeforeEach(func() {
 | 
						|
		mr, err = miniredis.Run()
 | 
						|
		Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
		client, err = redis.NewRedisClient(getOptsFunc(mr))
 | 
						|
		Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
		nonce, err := encryption.Nonce(32)
 | 
						|
		Expect(err).ToNot(HaveOccurred())
 | 
						|
		key = base64.RawURLEncoding.EncodeToString(nonce)
 | 
						|
 | 
						|
		ctx = context.Background()
 | 
						|
	})
 | 
						|
 | 
						|
	AfterEach(func() {
 | 
						|
		if mr != nil {
 | 
						|
			mr.Close()
 | 
						|
			mr = nil
 | 
						|
		}
 | 
						|
	})
 | 
						|
 | 
						|
	Context("when Get is called", func() {
 | 
						|
		expectedValue := []byte("value")
 | 
						|
 | 
						|
		BeforeEach(func() {
 | 
						|
			client.Set(context.Background(), key, expectedValue, time.Duration(1*time.Minute))
 | 
						|
		})
 | 
						|
 | 
						|
		It("returns the saved value", func() {
 | 
						|
			value, err := client.Get(ctx, key)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
			Expect(value).To(Equal(value))
 | 
						|
		})
 | 
						|
 | 
						|
		It("does not return expired values", func() {
 | 
						|
			mr.FastForward(5 * time.Minute)
 | 
						|
 | 
						|
			_, err = client.Get(ctx, key)
 | 
						|
			Expect(err).To(HaveOccurred())
 | 
						|
		})
 | 
						|
 | 
						|
		It("returns an error if value does not exist", func() {
 | 
						|
			_, err = client.Get(ctx, "does-not-exists")
 | 
						|
			Expect(err).To(HaveOccurred())
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("using Lock", func() {
 | 
						|
		It("maintains the lock", func() {
 | 
						|
			lock := client.Lock(key)
 | 
						|
 | 
						|
			err = lock.Obtain(ctx, 1*time.Minute)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			isLocked, err := lock.Peek(ctx)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
			Expect(isLocked).To(BeTrue())
 | 
						|
 | 
						|
			err = lock.Release(ctx)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
		})
 | 
						|
 | 
						|
		It("reflects non-locked instance", func() {
 | 
						|
			lock := client.Lock(key)
 | 
						|
 | 
						|
			isLocked, err := lock.Peek(ctx)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
			Expect(isLocked).To(BeFalse())
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("when Set is called", func() {
 | 
						|
		expectedValue := []byte("value")
 | 
						|
 | 
						|
		It("sets the expected value", func() {
 | 
						|
			err = client.Set(ctx, key, expectedValue, 1*time.Minute)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			value, err := client.Get(ctx, key)
 | 
						|
			Expect(value).To(Equal(expectedValue))
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("when Del is called", func() {
 | 
						|
		It("does not return an error when key exists", func() {
 | 
						|
			err = client.Set(ctx, key, []byte("dummy"), 1*time.Minute)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			err = client.Del(ctx, key)
 | 
						|
			Expect(err).ToNot(HaveOccurred())
 | 
						|
 | 
						|
			_, err = client.Get(ctx, key)
 | 
						|
			Expect(err).To(HaveOccurred())
 | 
						|
		})
 | 
						|
	})
 | 
						|
 | 
						|
	Context("when Ping is called", func() {
 | 
						|
		Context("when redis is up", func() {
 | 
						|
			It("does not return an error", func() {
 | 
						|
				err = client.Ping(ctx)
 | 
						|
				Expect(err).ToNot(HaveOccurred())
 | 
						|
			})
 | 
						|
		})
 | 
						|
 | 
						|
		Context("when redis is down", func() {
 | 
						|
			It("returns an error", func() {
 | 
						|
				mr.Close()
 | 
						|
				mr = nil
 | 
						|
 | 
						|
				err = client.Ping(ctx)
 | 
						|
				Expect(err).To(HaveOccurred())
 | 
						|
			})
 | 
						|
		})
 | 
						|
	})
 | 
						|
}
 |