From f753ec1ca5c42cd4c4e400afa1f76817638f3010 Mon Sep 17 00:00:00 2001 From: Kobi Meirson Date: Fri, 23 Dec 2022 11:08:12 +0200 Subject: [PATCH] feat: readiness check (#1839) * feat: readiness check * fix: no need for query param * docs: add a note * chore: move the readyness check to its own endpoint * docs(cr): add godoc Co-authored-by: Joel Speed --- CHANGELOG.md | 2 + .../kubernetes/oauth2-proxy-example-full.yaml | 4 +- contrib/oauth2-proxy_autocomplete.sh | 2 +- docs/docs/configuration/overview.md | 5 +- docs/docs/features/endpoints.md | 1 + oauthproxy.go | 6 +- pkg/apis/options/load_test.go | 1 + pkg/apis/options/logging.go | 2 +- pkg/apis/options/options.go | 3 + pkg/apis/sessions/interfaces.go | 1 + pkg/middleware/readynesscheck.go | 40 +++++ pkg/middleware/readynesscheck_test.go | 84 +++++++++ pkg/middleware/stored_session_test.go | 4 + pkg/sessions/cookie/session_store.go | 7 + pkg/sessions/persistence/interfaces.go | 1 + pkg/sessions/persistence/manager.go | 6 + pkg/sessions/redis/client.go | 9 + pkg/sessions/redis/client_test.go | 159 ++++++++++++++++++ pkg/sessions/redis/redis_store.go | 8 + pkg/sessions/redis/redis_store_test.go | 27 --- pkg/sessions/redis/redis_test.go | 35 ++++ pkg/sessions/tests/mock_store.go | 4 + pkg/sessions/tests/session_store_tests.go | 6 + 23 files changed, 382 insertions(+), 35 deletions(-) create mode 100644 pkg/middleware/readynesscheck.go create mode 100644 pkg/middleware/readynesscheck_test.go create mode 100644 pkg/sessions/redis/client_test.go create mode 100644 pkg/sessions/redis/redis_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index f9900432..59bafc91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,8 +12,10 @@ - [#1882](https://github.com/oauth2-proxy/oauth2-proxy/pull/1882) Make `htpasswd.GetUsers` racecondition safe - [#1883](https://github.com/oauth2-proxy/oauth2-proxy/pull/1883) Ensure v8 manifest variant is set on docker images - [#1906](https://github.com/oauth2-proxy/oauth2-proxy/pull/1906) Fix PKCE code verifier generation to never use UTF-8 characters +- [#1839](https://github.com/oauth2-proxy/oauth2-proxy/pull/1839) Add readiness checks for deeper health checks (@kobim) - [#1927](https://github.com/oauth2-proxy/oauth2-proxy/pull/1927) Fix default scope settings for none oidc providers + # V7.4.0 ## Release Highlights diff --git a/contrib/local-environment/kubernetes/oauth2-proxy-example-full.yaml b/contrib/local-environment/kubernetes/oauth2-proxy-example-full.yaml index 8ccc5977..197f81a8 100644 --- a/contrib/local-environment/kubernetes/oauth2-proxy-example-full.yaml +++ b/contrib/local-environment/kubernetes/oauth2-proxy-example-full.yaml @@ -449,11 +449,11 @@ spec: timeoutSeconds: 1 readinessProbe: httpGet: - path: /ping + path: /ready port: http scheme: HTTP initialDelaySeconds: 0 - timeoutSeconds: 1 + timeoutSeconds: 5 successThreshold: 1 periodSeconds: 10 resources: diff --git a/contrib/oauth2-proxy_autocomplete.sh b/contrib/oauth2-proxy_autocomplete.sh index 5c11738b..0dd8d304 100644 --- a/contrib/oauth2-proxy_autocomplete.sh +++ b/contrib/oauth2-proxy_autocomplete.sh @@ -24,7 +24,7 @@ _oauth2_proxy() { COMPREPLY=( $(compgen -W 'X-Real-IP X-Forwarded-For X-ProxyUser-IP' -- ${cur}) ) return 0 ;; - --@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) + --@(http-address|https-address|redirect-url|upstream|basic-auth-password|skip-auth-regex|flush-interval|extra-jwt-issuers|email-domain|whitelist-domain|trusted-ip|keycloak-group|azure-tenant|bitbucket-team|bitbucket-repository|github-org|github-team|github-repo|github-token|gitlab-group|github-user|google-group|google-admin-email|google-service-account-json|client-id|client_secret|banner|footer|proxy-prefix|ping-path|ready-path|cookie-name|cookie-secret|cookie-domain|cookie-path|cookie-expire|cookie-refresh|cookie-samesite|redist-sentinel-master-name|redist-sentinel-connection-urls|redist-cluster-connection-urls|logging-max-size|logging-max-age|logging-max-backups|standard-logging-format|request-logging-format|exclude-logging-paths|auth-logging-format|oidc-issuer-url|oidc-jwks-url|login-url|redeem-url|profile-url|resource|validate-url|scope|approval-prompt|signature-key|acr-values|jwt-key|pubjwk-url|force-json-errors)) return 0 ;; esac diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index 10e0d810..e36544e0 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -155,6 +155,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--provider-display-name` | string | Override the provider's name with the given string; used for the sign-in page | (depends on provider) | | `--ping-path` | string | the ping endpoint that can be used for basic health checks | `"/ping"` | | `--ping-user-agent` | string | a User-Agent that can be used for basic health checks | `""` (don't check user agent) | +| `--ready-path` | string | the ready endpoint that can be used for deep health checks | `"/ready"` | | `--metrics-address` | string | the address prometheus metrics will be scraped from | `""` | | `--proxy-prefix` | string | the url root path that this proxy should be nested under (e.g. /`/sign_in`) | `"/oauth2"` | | `--proxy-websockets` | bool | enables WebSocket proxying | true | @@ -184,7 +185,7 @@ An example [oauth2-proxy.cfg](https://github.com/oauth2-proxy/oauth2-proxy/blob/ | `--set-basic-auth` | bool | set HTTP Basic Auth information in response (useful in Nginx auth_request mode) | false | | `--show-debug-on-error` | bool | show detailed error information on error pages (WARNING: this may contain sensitive information - do not use in production) | false | | `--signature-key` | string | GAP-Signature request signature key (algorithm:secretkey) | | -| `--silence-ping-logging` | bool | disable logging of requests to ping endpoint | false | +| `--silence-ping-logging` | bool | disable logging of requests to ping & ready endpoints | false | | `--skip-auth-preflight` | bool | will skip authentication for OPTIONS requests | false | | `--skip-auth-regex` | string \| list | (DEPRECATED for `--skip-auth-route`) bypass authentication for requests paths that match (may be given multiple times) | | | `--skip-auth-route` | string \| list | bypass authentication for requests that match the method & path. Format: method=path_regex OR method!=path_regex. For all methods: path_regex OR !=path_regex | | @@ -246,7 +247,7 @@ There are three different types of logging: standard, authentication, and HTTP r Each type of logging has its own configurable format and variables. By default these formats are similar to the Apache Combined Log. -Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) can be disabled with `--silence-ping-logging` reducing log volume. This flag appends the `--ping-path` to `--exclude-logging-paths`. +Logging of requests to the `/ping` endpoint (or using `--ping-user-agent`) and the `/ready` endpoint can be disabled with `--silence-ping-logging` reducing log volume. ### Auth Log Format Authentication logs are logs which are guaranteed to contain a username or email address of a user attempting to authenticate. These logs are output by default in the below format: diff --git a/docs/docs/features/endpoints.md b/docs/docs/features/endpoints.md index bacb5439..ba3210bb 100644 --- a/docs/docs/features/endpoints.md +++ b/docs/docs/features/endpoints.md @@ -7,6 +7,7 @@ OAuth2 Proxy responds directly to the following endpoints. All other endpoints w - /robots.txt - returns a 200 OK response that disallows all User-agents from all paths; see [robotstxt.org](http://www.robotstxt.org/) for more info - /ping - returns a 200 OK response, which is intended for use with health checks +- /ready - returns a 200 OK response if all the underlying connections (e.g., Redis store) are connected - /metrics - Metrics endpoint for Prometheus to scrape, serve on the address specified by `--metrics-address`, disabled by default - /oauth2/sign_in - the login page, which also doubles as a sign out page (it clears cookies) - /oauth2/sign_out - this URL is used to clear the session cookie diff --git a/oauthproxy.go b/oauthproxy.go index 5b50bf5e..a0d4bf7b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -185,7 +185,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } - preAuthChain, err := buildPreAuthChain(opts) + preAuthChain, err := buildPreAuthChain(opts, sessionStore) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } @@ -327,7 +327,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // buildPreAuthChain constructs a chain that should process every request before // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. -func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { +func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) if opts.ForceHTTPS { @@ -351,12 +351,14 @@ func buildPreAuthChain(opts *options.Options) (alice.Chain, error) { if opts.Logging.SilencePing { chain = chain.Append( middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), + middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), middleware.NewRequestLogger(), ) } else { chain = chain.Append( middleware.NewRequestLogger(), middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), + middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), ) } diff --git a/pkg/apis/options/load_test.go b/pkg/apis/options/load_test.go index cface514..5939b858 100644 --- a/pkg/apis/options/load_test.go +++ b/pkg/apis/options/load_test.go @@ -49,6 +49,7 @@ var _ = Describe("Load", func() { Options: Options{ ProxyPrefix: "/oauth2", PingPath: "/ping", + ReadyPath: "/ready", RealClientIPHeader: "X-Real-IP", ForceHTTPS: false, Cookie: cookieDefaults(), diff --git a/pkg/apis/options/logging.go b/pkg/apis/options/logging.go index 3ed79127..dfffd0fa 100644 --- a/pkg/apis/options/logging.go +++ b/pkg/apis/options/logging.go @@ -43,7 +43,7 @@ func loggingFlagSet() *pflag.FlagSet { flagSet.StringSlice("exclude-logging-path", []string{}, "Exclude logging requests to paths (eg: '/path1,/path2,/path3')") flagSet.Bool("logging-local-time", true, "If the time in log files and backup filenames are local or UTC time") - flagSet.Bool("silence-ping-logging", false, "Disable logging of requests to ping endpoint") + flagSet.Bool("silence-ping-logging", false, "Disable logging of requests to ping & ready endpoints") flagSet.String("request-id-header", "X-Request-Id", "Request header to use as the request ID") flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index c65f1244..0af8df3f 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -21,6 +21,7 @@ type Options struct { ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy_prefix"` PingPath string `flag:"ping-path" cfg:"ping_path"` PingUserAgent string `flag:"ping-user-agent" cfg:"ping_user_agent"` + ReadyPath string `flag:"ready-path" cfg:"ready_path"` ReverseProxy bool `flag:"reverse-proxy" cfg:"reverse_proxy"` RealClientIPHeader string `flag:"real-client-ip-header" cfg:"real_client_ip_header"` TrustedIPs []string `flag:"trusted-ip" cfg:"trusted_ips"` @@ -96,6 +97,7 @@ func NewOptions() *Options { ProxyPrefix: "/oauth2", Providers: providerDefaults(), PingPath: "/ping", + ReadyPath: "/ready", RealClientIPHeader: "X-Real-IP", ForceHTTPS: false, Cookie: cookieDefaults(), @@ -133,6 +135,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. //sign_in)") flagSet.String("ping-path", "/ping", "the ping endpoint that can be used for basic health checks") flagSet.String("ping-user-agent", "", "special User-Agent that will be used for basic health checks") + flagSet.String("ready-path", "/ready", "the ready endpoint that can be used for deep health checks") flagSet.String("session-store-type", "cookie", "the session storage provider to use") flagSet.Bool("session-cookie-minimal", false, "strip OAuth tokens from cookie session stores if they aren't needed (cookie session store only)") flagSet.String("redis-connection-url", "", "URL of redis server for redis session storage (eg: redis://HOST[:PORT])") diff --git a/pkg/apis/sessions/interfaces.go b/pkg/apis/sessions/interfaces.go index 6569f2f3..97c364cf 100644 --- a/pkg/apis/sessions/interfaces.go +++ b/pkg/apis/sessions/interfaces.go @@ -12,6 +12,7 @@ type SessionStore interface { Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error Load(req *http.Request) (*SessionState, error) Clear(rw http.ResponseWriter, req *http.Request) error + VerifyConnection(ctx context.Context) error } var ErrLockNotObtained = errors.New("lock: not obtained") diff --git a/pkg/middleware/readynesscheck.go b/pkg/middleware/readynesscheck.go new file mode 100644 index 00000000..aee871e1 --- /dev/null +++ b/pkg/middleware/readynesscheck.go @@ -0,0 +1,40 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + + "github.com/justinas/alice" +) + +// Verifiable an interface for an object that has a connection to external +// data source and exports a function to validate that connection +type Verifiable interface { + VerifyConnection(context.Context) error +} + +// NewReadynessCheck returns a middleware that performs deep health checks +// (verifies the connection to any underlying store) on a specific `path` +func NewReadynessCheck(path string, verifiable Verifiable) alice.Constructor { + return func(next http.Handler) http.Handler { + return readynessCheck(path, verifiable, next) + } +} + +func readynessCheck(path string, verifiable Verifiable, next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if path != "" && req.URL.EscapedPath() == path { + if err := verifiable.VerifyConnection(req.Context()); err != nil { + rw.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(rw, "error: %v", err) + return + } + rw.WriteHeader(http.StatusOK) + fmt.Fprintf(rw, "OK") + return + } + + next.ServeHTTP(rw, req) + }) +} diff --git a/pkg/middleware/readynesscheck_test.go b/pkg/middleware/readynesscheck_test.go new file mode 100644 index 00000000..53052b88 --- /dev/null +++ b/pkg/middleware/readynesscheck_test.go @@ -0,0 +1,84 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("ReadynessCheck suite", func() { + type requestTableInput struct { + readyPath string + healthVerifiable Verifiable + requestString string + expectedStatus int + expectedBody string + } + + DescribeTable("when serving a request", + func(in *requestTableInput) { + req := httptest.NewRequest("", in.requestString, nil) + + rw := httptest.NewRecorder() + + handler := NewReadynessCheck(in.readyPath, in.healthVerifiable)(http.NotFoundHandler()) + handler.ServeHTTP(rw, req) + + Expect(rw.Code).To(Equal(in.expectedStatus)) + Expect(rw.Body.String()).To(Equal(in.expectedBody)) + }, + Entry("when requesting the readyness check path", &requestTableInput{ + readyPath: "/ready", + healthVerifiable: &fakeVerifiable{nil}, + requestString: "http://example.com/ready", + expectedStatus: 200, + expectedBody: "OK", + }), + Entry("when requesting a different path", &requestTableInput{ + readyPath: "/ready", + healthVerifiable: &fakeVerifiable{nil}, + requestString: "http://example.com/different", + expectedStatus: 404, + expectedBody: "404 page not found\n", + }), + Entry("when a blank string is configured as a readyness check path and the request has no specific path", &requestTableInput{ + readyPath: "", + healthVerifiable: &fakeVerifiable{nil}, + requestString: "http://example.com", + expectedStatus: 404, + expectedBody: "404 page not found\n", + }), + Entry("with full health check and without an underlying error", &requestTableInput{ + readyPath: "/ready", + healthVerifiable: &fakeVerifiable{nil}, + requestString: "http://example.com/ready", + expectedStatus: 200, + expectedBody: "OK", + }), + Entry("with full health check and with an underlying error", &requestTableInput{ + readyPath: "/ready", + healthVerifiable: &fakeVerifiable{func(ctx context.Context) error { return errors.New("failed to check") }}, + requestString: "http://example.com/ready", + expectedStatus: 500, + expectedBody: "error: failed to check", + }), + ) +}) + +type fakeVerifiable struct { + mock func(context.Context) error +} + +func (v *fakeVerifiable) VerifyConnection(ctx context.Context) error { + if v.mock != nil { + return v.mock(ctx) + } + return nil +} + +var _ Verifiable = (*fakeVerifiable)(nil) diff --git a/pkg/middleware/stored_session_test.go b/pkg/middleware/stored_session_test.go index c462278f..4dd7292f 100644 --- a/pkg/middleware/stored_session_test.go +++ b/pkg/middleware/stored_session_test.go @@ -793,3 +793,7 @@ func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro } return nil } + +func (f *fakeSessionStore) VerifyConnection(_ context.Context) error { + return nil +} diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 035ae1f0..f2f4045f 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -1,6 +1,7 @@ package cookie import ( + "context" "errors" "fmt" "net/http" @@ -82,6 +83,12 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { return nil } +// VerifyConnection always return no-error, as there's no connection +// in this store +func (s *SessionStore) VerifyConnection(_ context.Context) error { + return nil +} + // cookieForSession serializes a session state for storage in a cookie func (s *SessionStore) cookieForSession(ss *sessions.SessionState) ([]byte, error) { if s.Minimal && (ss.AccessToken != "" || ss.IDToken != "" || ss.RefreshToken != "") { diff --git a/pkg/sessions/persistence/interfaces.go b/pkg/sessions/persistence/interfaces.go index b1f50da5..5bab9912 100644 --- a/pkg/sessions/persistence/interfaces.go +++ b/pkg/sessions/persistence/interfaces.go @@ -15,4 +15,5 @@ type Store interface { Load(context.Context, string) ([]byte, error) Clear(context.Context, string) error Lock(key string) sessions.Lock + VerifyConnection(context.Context) error } diff --git a/pkg/sessions/persistence/manager.go b/pkg/sessions/persistence/manager.go index 3215b257..9652f015 100644 --- a/pkg/sessions/persistence/manager.go +++ b/pkg/sessions/persistence/manager.go @@ -1,6 +1,7 @@ package persistence import ( + "context" "fmt" "net/http" "time" @@ -90,3 +91,8 @@ func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error { return m.Store.Clear(req.Context(), key) }) } + +// VerifyConnection validates the underlying store is ready and connected +func (m *Manager) VerifyConnection(ctx context.Context) error { + return m.Store.VerifyConnection(ctx) +} diff --git a/pkg/sessions/redis/client.go b/pkg/sessions/redis/client.go index ac1c540e..01360bdb 100644 --- a/pkg/sessions/redis/client.go +++ b/pkg/sessions/redis/client.go @@ -14,6 +14,7 @@ type Client interface { Lock(key string) sessions.Lock Set(ctx context.Context, key string, value []byte, expiration time.Duration) error Del(ctx context.Context, key string) error + Ping(ctx context.Context) error } var _ Client = (*client)(nil) @@ -44,6 +45,10 @@ func (c *client) Lock(key string) sessions.Lock { return NewLock(c.Client, key) } +func (c *client) Ping(ctx context.Context) error { + return c.Client.Ping(ctx).Err() +} + var _ Client = (*clusterClient)(nil) type clusterClient struct { @@ -71,3 +76,7 @@ func (c *clusterClient) Del(ctx context.Context, key string) error { func (c *clusterClient) Lock(key string) sessions.Lock { return NewLock(c.ClusterClient, key) } + +func (c *clusterClient) Ping(ctx context.Context) error { + return c.ClusterClient.Ping(ctx).Err() +} diff --git a/pkg/sessions/redis/client_test.go b/pkg/sessions/redis/client_test.go new file mode 100644 index 00000000..580947ca --- /dev/null +++ b/pkg/sessions/redis/client_test.go @@ -0,0 +1,159 @@ +package redis_test + +import ( + "context" + "encoding/base64" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/redis" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Redis Client Tests", func() { + Context("with basic client", func() { + RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions { + return options.RedisStoreOptions{ + ConnectionURL: "redis://" + mr.Addr(), + } + }) + }) + + Context("with cluster client", func() { + RunClientTests(func(mr *miniredis.Miniredis) options.RedisStoreOptions { + return options.RedisStoreOptions{ + ClusterConnectionURLs: []string{"redis://" + mr.Addr()}, + UseCluster: true, + } + }) + }) +}) + +type getOptsFunc func(mr *miniredis.Miniredis) options.RedisStoreOptions + +func RunClientTests(getOptsFunc getOptsFunc) { + var mr *miniredis.Miniredis + var client redis.Client + var err error + var key string + var ctx context.Context + + BeforeEach(func() { + mr, err = miniredis.Run() + Expect(err).ToNot(HaveOccurred()) + + client, err = redis.NewRedisClient(getOptsFunc(mr)) + Expect(err).ToNot(HaveOccurred()) + + nonce, err := encryption.Nonce(32) + Expect(err).ToNot(HaveOccurred()) + key = base64.RawURLEncoding.EncodeToString(nonce) + + ctx = context.Background() + }) + + AfterEach(func() { + if mr != nil { + mr.Close() + mr = nil + } + }) + + Context("when Get is called", func() { + expectedValue := []byte("value") + + BeforeEach(func() { + client.Set(context.Background(), key, expectedValue, time.Duration(1*time.Minute)) + }) + + It("returns the saved value", func() { + value, err := client.Get(ctx, key) + Expect(err).ToNot(HaveOccurred()) + Expect(value).To(Equal(value)) + }) + + It("does not return expired values", func() { + mr.FastForward(5 * time.Minute) + + _, err = client.Get(ctx, key) + Expect(err).To(HaveOccurred()) + }) + + It("returns an error if value does not exist", func() { + _, err = client.Get(ctx, "does-not-exists") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("using Lock", func() { + It("maintains the lock", func() { + lock := client.Lock(key) + + err = lock.Obtain(ctx, 1*time.Minute) + Expect(err).ToNot(HaveOccurred()) + + isLocked, err := lock.Peek(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(isLocked).To(BeTrue()) + + err = lock.Release(ctx) + Expect(err).ToNot(HaveOccurred()) + }) + + It("reflects non-locked instance", func() { + lock := client.Lock(key) + + isLocked, err := lock.Peek(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(isLocked).To(BeFalse()) + }) + }) + + Context("when Set is called", func() { + expectedValue := []byte("value") + + It("sets the expected value", func() { + err = client.Set(ctx, key, expectedValue, 1*time.Minute) + Expect(err).ToNot(HaveOccurred()) + + value, err := client.Get(ctx, key) + Expect(value).To(Equal(expectedValue)) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("when Del is called", func() { + It("does not return an error when key exists", func() { + err = client.Set(ctx, key, []byte("dummy"), 1*time.Minute) + Expect(err).ToNot(HaveOccurred()) + + err = client.Del(ctx, key) + Expect(err).ToNot(HaveOccurred()) + + _, err = client.Get(ctx, key) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when Ping is called", func() { + Context("when redis is up", func() { + It("does not return an error", func() { + err = client.Ping(ctx) + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("when redis is down", func() { + It("returns an error", func() { + mr.Close() + mr = nil + + err = client.Ping(ctx) + Expect(err).To(HaveOccurred()) + }) + }) + }) +} diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index fccea249..e8f444b6 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -70,6 +70,12 @@ func (store *SessionStore) Lock(key string) sessions.Lock { return store.Client.Lock(key) } +// VerifyConnection verifies the redis connection is valid and the +// server is responsive +func (store *SessionStore) VerifyConnection(ctx context.Context) error { + return store.Client.Ping(ctx) +} + // NewRedisClient makes a redis.Client (either standalone, sentinel aware, or // redis cluster) func NewRedisClient(opts options.RedisStoreOptions) (Client, error) { @@ -205,3 +211,5 @@ func parseRedisURLs(urls []string) ([]string, *redis.Options, error) { } return addrs, redisOptions, nil } + +var _ persistence.Store = (*SessionStore)(nil) diff --git a/pkg/sessions/redis/redis_store_test.go b/pkg/sessions/redis/redis_store_test.go index 14155043..9357b809 100644 --- a/pkg/sessions/redis/redis_store_test.go +++ b/pkg/sessions/redis/redis_store_test.go @@ -2,20 +2,15 @@ package redis import ( "bytes" - "context" "crypto/tls" "encoding/pem" - "log" "os" - "testing" "time" "github.com/Bose/minisentinel" "github.com/alicebob/miniredis/v2" - "github.com/go-redis/redis/v9" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/persistence" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/sessions/tests" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/util" @@ -25,28 +20,6 @@ import ( const redisPassword = "0123456789abcdefghijklmnopqrstuv" -// wrappedRedisLogger wraps a logger so that we can coerce the logger to -// fit the expected signature for go-redis logging -type wrappedRedisLogger struct { - *log.Logger -} - -func (l *wrappedRedisLogger) Printf(_ context.Context, format string, v ...interface{}) { - l.Logger.Printf(format, v...) -} - -func TestSessionStore(t *testing.T) { - logger.SetOutput(GinkgoWriter) - logger.SetErrOutput(GinkgoWriter) - - redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)} - redisLogger.SetOutput(GinkgoWriter) - redis.SetLogger(redisLogger) - - RegisterFailHandler(Fail) - RunSpecs(t, "Redis SessionStore") -} - var ( cert tls.Certificate caPath string diff --git a/pkg/sessions/redis/redis_test.go b/pkg/sessions/redis/redis_test.go new file mode 100644 index 00000000..5902c3c0 --- /dev/null +++ b/pkg/sessions/redis/redis_test.go @@ -0,0 +1,35 @@ +package redis_test + +import ( + "context" + "log" + "os" + "testing" + + "github.com/go-redis/redis/v9" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +// wrappedRedisLogger wraps a logger so that we can coerce the logger to +// fit the expected signature for go-redis logging +type wrappedRedisLogger struct { + *log.Logger +} + +func (l *wrappedRedisLogger) Printf(_ context.Context, format string, v ...interface{}) { + l.Logger.Printf(format, v...) +} + +func TestRedis(t *testing.T) { + logger.SetOutput(GinkgoWriter) + logger.SetErrOutput(GinkgoWriter) + + redisLogger := &wrappedRedisLogger{Logger: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile)} + redisLogger.SetOutput(GinkgoWriter) + redis.SetLogger(redisLogger) + + RegisterFailHandler(Fail) + RunSpecs(t, "Redis") +} diff --git a/pkg/sessions/tests/mock_store.go b/pkg/sessions/tests/mock_store.go index 0ea66145..c82d8c08 100644 --- a/pkg/sessions/tests/mock_store.go +++ b/pkg/sessions/tests/mock_store.go @@ -65,6 +65,10 @@ func (s *MockStore) Lock(key string) sessions.Lock { return lock } +func (s *MockStore) VerifyConnection(_ context.Context) error { + return nil +} + // FastForward simulates the flow of time to test expirations func (s *MockStore) FastForward(duration time.Duration) { for _, mockLock := range s.lockCache { diff --git a/pkg/sessions/tests/session_store_tests.go b/pkg/sessions/tests/session_store_tests.go index 80debebb..0416b24f 100644 --- a/pkg/sessions/tests/session_store_tests.go +++ b/pkg/sessions/tests/session_store_tests.go @@ -485,6 +485,12 @@ func SessionStoreInterfaceTests(in *testInput) { }) }) }) + + Context("when VerifyConnection is called", func() { + It("should return without an error", func() { + Expect(in.ss().VerifyConnection(in.request.Context())).ToNot(HaveOccurred()) + }) + }) } func LoadSessionTests(in *testInput) {