From 6fb29dc798c06d5d0d7a23d7c1b5c37d35fe423e Mon Sep 17 00:00:00 2001 From: Adel Salakh Date: Tue, 13 May 2025 19:10:07 +0200 Subject: [PATCH 1/4] add graceful shutdown Signed-off-by: Jan Larwig Signed-off-by: Adel Salakh --- main.go | 1 + oauthproxy.go | 23 +++++++++++++++-------- pkg/apis/options/options.go | 4 ++++ pkg/middleware/readynesscheck.go | 11 ++++++++--- pkg/proxyhttp/server.go | 15 ++++++++++++++- 5 files changed, 42 insertions(+), 12 deletions(-) diff --git a/main.go b/main.go index 42e8bab0..af30cbd7 100644 --- a/main.go +++ b/main.go @@ -54,6 +54,7 @@ func main() { } validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) + oauthproxy, err := NewOAuthProxy(opts, validator) if err != nil { logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) diff --git a/oauthproxy.go b/oauthproxy.go index d933d930..b282b8fa 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -113,6 +113,9 @@ type OAuthProxy struct { redirectValidator redirect.Validator appDirector redirect.AppDirector + cancelFunc context.CancelFunc + cancelCtx context.Context + encodeState bool } @@ -200,7 +203,9 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } - preAuthChain, err := buildPreAuthChain(opts, sessionStore) + cancelCtx, cancelFunc := context.WithCancel(context.Background()) + + preAuthChain, err := buildPreAuthChain(opts, sessionStore, cancelCtx) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } @@ -248,6 +253,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr redirectValidator: redirectValidator, appDirector: appDirector, encodeState: opts.EncodeState, + cancelFunc: cancelFunc, + cancelCtx: cancelCtx, } p.buildServeMux(opts.ProxyPrefix) @@ -265,17 +272,15 @@ func (p *OAuthProxy) Start() error { panic("server has not been initialised") } - ctx, cancel := context.WithCancel(context.Background()) - // Observe signals in background goroutine. go func() { sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) <-sigint - cancel() // cancel the context + p.cancelFunc() // cancel the context }() - return p.server.Start(ctx) + return p.server.Start(p.cancelCtx) } func (p *OAuthProxy) setupServer(opts *options.Options) error { @@ -284,6 +289,7 @@ func (p *OAuthProxy) setupServer(opts *options.Options) error { BindAddress: opts.Server.BindAddress, SecureBindAddress: opts.Server.SecureBindAddress, TLS: opts.Server.TLS, + ShutdownDuration: opts.ShutdownDuration, } // Option: AllowQuerySemicolons @@ -353,7 +359,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // buildPreAuthChain constructs a chain that should process every request before // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. -func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) { +func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore, ctx context.Context) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) if opts.ForceHTTPS { @@ -374,17 +380,18 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt // To silence logging of health checks, register the health check handler before // the logging handler + readinessCheck := middleware.NewReadynessCheck(opts.ReadyPath, sessionStore, ctx) if opts.Logging.SilencePing { chain = chain.Append( middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), - middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), + readinessCheck, middleware.NewRequestLogger(), ) } else { chain = chain.Append( middleware.NewRequestLogger(), middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), - middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), + readinessCheck, ) } diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index b57d5aed..38d48a37 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -3,6 +3,7 @@ package options import ( "crypto" "net/url" + "time" ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" @@ -29,6 +30,8 @@ type Options struct { RawRedirectURL string `flag:"redirect-url" cfg:"redirect_url"` RelativeRedirectURL bool `flag:"relative-redirect-url" cfg:"relative_redirect_url"` + ShutdownDuration time.Duration `flag:"shutdown-duration" cfg:"shutdown_duration"` + AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` EmailDomains []string `flag:"email-domain" cfg:"email_domains"` WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"` @@ -161,6 +164,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") + flagSet.Duration("shutdown-duration", 0, "Amount of time to continue serving traffic after receiving an exit signal with readiness endpoint set to false.") flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet()) diff --git a/pkg/middleware/readynesscheck.go b/pkg/middleware/readynesscheck.go index aee871e1..bf3e6373 100644 --- a/pkg/middleware/readynesscheck.go +++ b/pkg/middleware/readynesscheck.go @@ -16,14 +16,19 @@ type Verifiable interface { // NewReadynessCheck returns a middleware that performs deep health checks // (verifies the connection to any underlying store) on a specific `path` -func NewReadynessCheck(path string, verifiable Verifiable) alice.Constructor { +func NewReadynessCheck(path string, verifiable Verifiable, ctx context.Context) alice.Constructor { return func(next http.Handler) http.Handler { - return readynessCheck(path, verifiable, next) + return readynessCheck(path, verifiable, next, ctx) } } -func readynessCheck(path string, verifiable Verifiable, next http.Handler) http.Handler { +func readynessCheck(path string, verifiable Verifiable, next http.Handler, ctx context.Context) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if ctx.Err() != nil { + rw.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(rw, "Shutting down") + return + } if path != "" && req.URL.EscapedPath() == path { if err := verifiable.VerifyConnection(req.Context()); err != nil { rw.WriteHeader(http.StatusInternalServerError) diff --git a/pkg/proxyhttp/server.go b/pkg/proxyhttp/server.go index a0fc6054..5372d1c3 100644 --- a/pkg/proxyhttp/server.go +++ b/pkg/proxyhttp/server.go @@ -40,12 +40,16 @@ type Opts struct { // Let testing infrastructure circumvent parsing file descriptors fdFiles []*os.File + + // Graceful shutdown duration + ShutdownDuration time.Duration } // NewServer creates a new Server from the options given. func NewServer(opts Opts) (Server, error) { s := &server{ - handler: opts.Handler, + handler: opts.Handler, + shutdownDuration: opts.ShutdownDuration, } if len(opts.fdFiles) > 0 { @@ -71,6 +75,9 @@ type server struct { // ensure activation.Files are called once fdFiles []*os.File + + // Graceful shutdown duration + shutdownDuration time.Duration } // setupListener sets the server listener if the HTTP server is enabled. @@ -214,10 +221,16 @@ func (s *server) startServer(ctx context.Context, listener net.Listener) error { g.Go(func() error { <-groupCtx.Done() + logger.Printf("Context canceled. Waiting %s before shutting down the listeners.", s.shutdownDuration) + + time.Sleep(s.shutdownDuration) + + logger.Printf("Shutting down listener.") if err := srv.Shutdown(context.Background()); err != nil { return fmt.Errorf("error shutting down server: %v", err) } + return nil }) From ce56ff135e18802d898037c097aff473176f5418 Mon Sep 17 00:00:00 2001 From: Adel Salakh Date: Thu, 15 May 2025 00:31:46 +0200 Subject: [PATCH 2/4] fix readiness check test Signed-off-by: Jan Larwig Signed-off-by: Adel Salakh --- pkg/middleware/readynesscheck_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/middleware/readynesscheck_test.go b/pkg/middleware/readynesscheck_test.go index 618c27b8..a00b8d88 100644 --- a/pkg/middleware/readynesscheck_test.go +++ b/pkg/middleware/readynesscheck_test.go @@ -25,7 +25,9 @@ var _ = Describe("ReadynessCheck suite", func() { rw := httptest.NewRecorder() - handler := NewReadynessCheck(in.readyPath, in.healthVerifiable)(http.NotFoundHandler()) + ctx := context.Background() + + handler := NewReadynessCheck(in.readyPath, in.healthVerifiable, ctx)(http.NotFoundHandler()) handler.ServeHTTP(rw, req) Expect(rw.Code).To(Equal(in.expectedStatus)) From 225ace3a2cdec8c2710d6ea0312344787f8f0dc0 Mon Sep 17 00:00:00 2001 From: Adel Salakh Date: Thu, 15 May 2025 01:03:08 +0200 Subject: [PATCH 3/4] refactor Signed-off-by: Jan Larwig Signed-off-by: Adel Salakh --- main.go | 17 ++++++++-- oauthproxy.go | 30 ++++------------- oauthproxy_test.go | 48 +++++++++++++-------------- pkg/middleware/readynesscheck.go | 6 ++-- pkg/middleware/readynesscheck_test.go | 2 +- 5 files changed, 49 insertions(+), 54 deletions(-) diff --git a/main.go b/main.go index af30cbd7..3f0b80b4 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,12 @@ package main import ( + "context" "fmt" "os" + "os/signal" "runtime" + "syscall" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" @@ -55,12 +58,22 @@ func main() { validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) - oauthproxy, err := NewOAuthProxy(opts, validator) + ctx, cancelFunc := context.WithCancel(context.Background()) + + // Observe signals in background goroutine. + go func() { + sigint := make(chan os.Signal, 1) + signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) + <-sigint + cancelFunc() // cancel the context + }() + + oauthproxy, err := NewOAuthProxy(ctx, opts, validator) if err != nil { logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) } - if err := oauthproxy.Start(); err != nil { + if err := oauthproxy.Start(ctx); err != nil { logger.Fatalf("ERROR: Failed to start OAuth2 Proxy: %v", err) } } diff --git a/oauthproxy.go b/oauthproxy.go index b282b8fa..827cc8d8 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -10,11 +10,8 @@ import ( "net" "net/http" "net/url" - "os" - "os/signal" "regexp" "strings" - "syscall" "time" "github.com/gorilla/mux" @@ -113,14 +110,11 @@ type OAuthProxy struct { redirectValidator redirect.Validator appDirector redirect.AppDirector - cancelFunc context.CancelFunc - cancelCtx context.Context - encodeState bool } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided -func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { +func NewOAuthProxy(ctx context.Context, opts *options.Options, validator func(string) bool) (*OAuthProxy, error) { sessionStore, err := sessions.NewSessionStore(&opts.Session, &opts.Cookie) if err != nil { return nil, fmt.Errorf("error initialising session store: %v", err) @@ -203,9 +197,7 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } - cancelCtx, cancelFunc := context.WithCancel(context.Background()) - - preAuthChain, err := buildPreAuthChain(opts, sessionStore, cancelCtx) + preAuthChain, err := buildPreAuthChain(ctx, opts, sessionStore) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } @@ -253,8 +245,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr redirectValidator: redirectValidator, appDirector: appDirector, encodeState: opts.EncodeState, - cancelFunc: cancelFunc, - cancelCtx: cancelCtx, } p.buildServeMux(opts.ProxyPrefix) @@ -265,22 +255,14 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return p, nil } -func (p *OAuthProxy) Start() error { +func (p *OAuthProxy) Start(ctx context.Context) error { if p.server == nil { // We have to call setupServer before Start is called. // If this doesn't happen it's a programming error. panic("server has not been initialised") } - // Observe signals in background goroutine. - go func() { - sigint := make(chan os.Signal, 1) - signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) - <-sigint - p.cancelFunc() // cancel the context - }() - - return p.server.Start(p.cancelCtx) + return p.server.Start(ctx) } func (p *OAuthProxy) setupServer(opts *options.Options) error { @@ -359,7 +341,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // buildPreAuthChain constructs a chain that should process every request before // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. -func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore, ctx context.Context) (alice.Chain, error) { +func buildPreAuthChain(ctx context.Context, opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) if opts.ForceHTTPS { @@ -380,7 +362,7 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt // To silence logging of health checks, register the health check handler before // the logging handler - readinessCheck := middleware.NewReadynessCheck(opts.ReadyPath, sessionStore, ctx) + readinessCheck := middleware.NewReadynessCheck(ctx, opts.ReadyPath, sessionStore) if opts.Logging.SilencePing { chain = chain.Append( middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), diff --git a/oauthproxy_test.go b/oauthproxy_test.go index ccabdbbd..dba926f7 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -48,7 +48,7 @@ func TestRobotsTxt(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } @@ -109,7 +109,7 @@ func Test_redeemCode(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } @@ -161,7 +161,7 @@ func Test_enrichSession(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) if err != nil { t.Fatal(err) } @@ -230,7 +230,7 @@ func TestBasicAuthPassword(t *testing.T) { providerURL, _ := url.Parse(providerServer.URL) const emailAddress = "john.doe@example.com" - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return email == emailAddress }) if err != nil { @@ -292,7 +292,7 @@ func TestPassGroupsHeadersWithGroups(t *testing.T) { CreatedAt: &created, } - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return email == emailAddress }) assert.NoError(t, err) @@ -388,7 +388,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) (*PassAccessTokenTe testProvider := NewTestProvider(providerURL, emailAddress) testProvider.ValidToken = opts.ValidToken - patt.proxy, err = NewOAuthProxy(patt.opts, func(email string) bool { + patt.proxy, err = NewOAuthProxy(context.Background(), patt.opts, func(email string) bool { return email == emailAddress }) patt.proxy.provider = testProvider @@ -593,7 +593,7 @@ func NewSignInPageTest(skipProvider bool) (*SignInPageTest, error) { return nil, err } - sipTest.proxy, err = NewOAuthProxy(sipTest.opts, func(email string) bool { + sipTest.proxy, err = NewOAuthProxy(context.Background(), sipTest.opts, func(email string) bool { return true }) if err != nil { @@ -629,7 +629,7 @@ func TestManualSignInStoresUserGroupsInTheSession(t *testing.T) { t.Fatal(err) } - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return true }) if err != nil { @@ -677,7 +677,7 @@ func ManualSignInWithCredentials(t *testing.T, user, pass string) int { t.Fatal(err) } - proxy, err := NewOAuthProxy(opts, func(email string) bool { + proxy, err := NewOAuthProxy(context.Background(), opts, func(email string) bool { return true }) if err != nil { @@ -827,7 +827,7 @@ func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifi return nil, err } - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1201,7 +1201,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { err := validation.Validate(pcTest.opts) assert.NoError(t, err) - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1294,7 +1294,7 @@ func TestAuthOnlyEndpointSetBasicAuthTrueRequestHeaders(t *testing.T) { err := validation.Validate(pcTest.opts) assert.NoError(t, err) - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1374,7 +1374,7 @@ func TestAuthOnlyEndpointSetBasicAuthFalseRequestHeaders(t *testing.T) { err := validation.Validate(pcTest.opts) assert.NoError(t, err) - pcTest.proxy, err = NewOAuthProxy(pcTest.opts, func(email string) bool { + pcTest.proxy, err = NewOAuthProxy(context.Background(), pcTest.opts, func(email string) bool { return pcTest.validateUser }) if err != nil { @@ -1430,7 +1430,7 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { upstreamURL, _ := url.Parse(upstreamServer.URL) - proxy, err := NewOAuthProxy(opts, func(string) bool { return false }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return false }) if err != nil { t.Fatal(err) } @@ -1555,7 +1555,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) er if err != nil { return err } - proxy, err := NewOAuthProxy(st.opts, func(email string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), st.opts, func(email string) bool { return true }) if err != nil { return err } @@ -1643,7 +1643,7 @@ func newAjaxRequestTest(forceJSONErrors bool) (*ajaxRequestTest, error) { return nil, err } - test.proxy, err = NewOAuthProxy(test.opts, func(email string) bool { + test.proxy, err = NewOAuthProxy(context.Background(), test.opts, func(email string) bool { return true }) if err != nil { @@ -1959,7 +1959,7 @@ func Test_noCacheHeaders(t *testing.T) { opts.SkipAuthRegex = []string{".*"} err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2234,7 +2234,7 @@ func TestTrustedIPs(t *testing.T) { err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(string) bool { return true }) assert.NoError(t, err) rw := httptest.NewRecorder() @@ -2479,7 +2479,7 @@ func TestApiRoutes(t *testing.T) { opts.SkipProviderButton = true err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2562,7 +2562,7 @@ func TestAllowedRequest(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2661,7 +2661,7 @@ func TestAllowedRequestWithForwardedUriHeader(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2760,7 +2760,7 @@ func TestAllowedRequestNegateWithoutMethod(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -2860,7 +2860,7 @@ func TestAllowedRequestNegateWithMethod(t *testing.T) { } err := validation.Validate(opts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(opts, func(_ string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), opts, func(_ string) bool { return true }) if err != nil { t.Fatal(err) } @@ -3524,7 +3524,7 @@ func TestGetOAuthRedirectURI(t *testing.T) { err := validation.Validate(baseOpts) assert.NoError(t, err) - proxy, err := NewOAuthProxy(tt.setupOpts(baseOpts), func(string) bool { return true }) + proxy, err := NewOAuthProxy(context.Background(), tt.setupOpts(baseOpts), func(string) bool { return true }) assert.NoError(t, err) assert.Equalf(t, tt.want, proxy.getOAuthRedirectURI(tt.req), "getOAuthRedirectURI(%v)", tt.req) diff --git a/pkg/middleware/readynesscheck.go b/pkg/middleware/readynesscheck.go index bf3e6373..6977cf2e 100644 --- a/pkg/middleware/readynesscheck.go +++ b/pkg/middleware/readynesscheck.go @@ -16,13 +16,13 @@ type Verifiable interface { // NewReadynessCheck returns a middleware that performs deep health checks // (verifies the connection to any underlying store) on a specific `path` -func NewReadynessCheck(path string, verifiable Verifiable, ctx context.Context) alice.Constructor { +func NewReadynessCheck(ctx context.Context, path string, verifiable Verifiable) alice.Constructor { return func(next http.Handler) http.Handler { - return readynessCheck(path, verifiable, next, ctx) + return readynessCheck(ctx, path, verifiable, next) } } -func readynessCheck(path string, verifiable Verifiable, next http.Handler, ctx context.Context) http.Handler { +func readynessCheck(ctx context.Context, path string, verifiable Verifiable, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { if ctx.Err() != nil { rw.WriteHeader(http.StatusServiceUnavailable) diff --git a/pkg/middleware/readynesscheck_test.go b/pkg/middleware/readynesscheck_test.go index a00b8d88..952fc3f0 100644 --- a/pkg/middleware/readynesscheck_test.go +++ b/pkg/middleware/readynesscheck_test.go @@ -27,7 +27,7 @@ var _ = Describe("ReadynessCheck suite", func() { ctx := context.Background() - handler := NewReadynessCheck(in.readyPath, in.healthVerifiable, ctx)(http.NotFoundHandler()) + handler := NewReadynessCheck(ctx, in.readyPath, in.healthVerifiable)(http.NotFoundHandler()) handler.ServeHTTP(rw, req) Expect(rw.Code).To(Equal(in.expectedStatus)) From 20bb8e8b196a94bd5f5e55c606758c255b7a347d Mon Sep 17 00:00:00 2001 From: Adel Salakh Date: Tue, 27 May 2025 20:19:12 +0200 Subject: [PATCH 4/4] fix review comments Signed-off-by: Jan Larwig Signed-off-by: Adel Salakh --- CHANGELOG.md | 1 + docs/docs/configuration/alpha_config.md | 1 + docs/docs/configuration/overview.md | 2 ++ oauthproxy.go | 2 +- pkg/apis/options/legacy_options.go | 23 +++++++++++++---------- pkg/apis/options/options.go | 4 ---- pkg/apis/options/server.go | 7 +++++++ pkg/middleware/readynesscheck.go | 2 ++ pkg/proxyhttp/server.go | 2 -- 9 files changed, 27 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0a800f1..02b10442 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - [#3304](https://github.com/oauth2-proxy/oauth2-proxy/pull/3304) fix: added conditional so default is not always set and env vars are honored fixes 3303 (@pixeldrew) - [#3264](https://github.com/oauth2-proxy/oauth2-proxy/pull/3264) fix: more aggressively truncate logged access_token (@MartinNowak / @tuunit) - [#3267](https://github.com/oauth2-proxy/oauth2-proxy/pull/3267) fix: Session refresh handling in OIDC provider (@gysel) +- [#3068](https://github.com/oauth2-proxy/oauth2-proxy/pull/3068) feat: graceful shutdown to prevent errors when oauth2-proxy is load balanced (@adelsz) # V7.13.0 diff --git a/docs/docs/configuration/alpha_config.md b/docs/docs/configuration/alpha_config.md index 495bc206..efcacc7e 100644 --- a/docs/docs/configuration/alpha_config.md +++ b/docs/docs/configuration/alpha_config.md @@ -495,6 +495,7 @@ Server represents the configuration for an HTTP(S) server | `bindAddress` | _string_ | BindAddress is the address on which to serve traffic.
Leave blank or set to "-" to disable. | | `secureBindAddress` | _string_ | SecureBindAddress is the address on which to serve secure traffic.
Leave blank or set to "-" to disable. | | `tls` | _[TLS](#tls)_ | TLS contains the information for loading the certificate and key for the
secure traffic and further configuration for the TLS server. | +| `shutdownDuration` | _duration_ | Duration of time to continue serving traffic after receiving an exit signal.
During this time the readiness endpoint will be returning HTTP 503 errors.
Leave blank to disable. | ### TLS diff --git a/docs/docs/configuration/overview.md b/docs/docs/configuration/overview.md index b159df09..2bc795c9 100644 --- a/docs/docs/configuration/overview.md +++ b/docs/docs/configuration/overview.md @@ -235,6 +235,8 @@ Provider specific options can be found on their respective subpages. | flag: `--tls-key-file`
toml: `tls_key_file` | string | path to private key file | | | flag: `--tls-cipher-suite`
toml: `tls_cipher_suites` | string \| list | Restricts TLS cipher suites used by server to those listed (e.g. TLS_RSA_WITH_RC4_128_SHA) (may be given multiple times). If not specified, the default Go safe cipher list is used. List of valid cipher suites can be found in the [crypto/tls documentation](https://pkg.go.dev/crypto/tls#pkg-constants). | | | flag: `--tls-min-version`
toml: `tls_min_version` | string | minimum TLS version that is acceptable, either `"TLS1.2"` or `"TLS1.3"` | `"TLS1.2"` | +| flag: `--shutdown-duration`
toml: `shutdown_duration` | duration | Duration of time to continue serving traffic after receiving an exit signal. During this duration the readiness endpoint will be returning HTTP 503 errors. | `"0s"` | + ### Session Options diff --git a/oauthproxy.go b/oauthproxy.go index 827cc8d8..67d19104 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -271,7 +271,7 @@ func (p *OAuthProxy) setupServer(opts *options.Options) error { BindAddress: opts.Server.BindAddress, SecureBindAddress: opts.Server.SecureBindAddress, TLS: opts.Server.TLS, - ShutdownDuration: opts.ShutdownDuration, + ShutdownDuration: opts.Server.ShutdownDuration, } // Option: AllowQuerySemicolons diff --git a/pkg/apis/options/legacy_options.go b/pkg/apis/options/legacy_options.go index 6ce730fd..e0ef0716 100644 --- a/pkg/apis/options/legacy_options.go +++ b/pkg/apis/options/legacy_options.go @@ -471,16 +471,17 @@ func getXAuthRequestAccessTokenHeader() Header { } type LegacyServer struct { - MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"` - MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_secure_address"` - MetricsTLSCertFile string `flag:"metrics-tls-cert-file" cfg:"metrics_tls_cert_file"` - MetricsTLSKeyFile string `flag:"metrics-tls-key-file" cfg:"metrics_tls_key_file"` - HTTPAddress string `flag:"http-address" cfg:"http_address"` - HTTPSAddress string `flag:"https-address" cfg:"https_address"` - TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"` - TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"` - TLSMinVersion string `flag:"tls-min-version" cfg:"tls_min_version"` - TLSCipherSuites []string `flag:"tls-cipher-suite" cfg:"tls_cipher_suites"` + MetricsAddress string `flag:"metrics-address" cfg:"metrics_address"` + MetricsSecureAddress string `flag:"metrics-secure-address" cfg:"metrics_secure_address"` + MetricsTLSCertFile string `flag:"metrics-tls-cert-file" cfg:"metrics_tls_cert_file"` + MetricsTLSKeyFile string `flag:"metrics-tls-key-file" cfg:"metrics_tls_key_file"` + HTTPAddress string `flag:"http-address" cfg:"http_address"` + HTTPSAddress string `flag:"https-address" cfg:"https_address"` + TLSCertFile string `flag:"tls-cert-file" cfg:"tls_cert_file"` + TLSKeyFile string `flag:"tls-key-file" cfg:"tls_key_file"` + TLSMinVersion string `flag:"tls-min-version" cfg:"tls_min_version"` + TLSCipherSuites []string `flag:"tls-cipher-suite" cfg:"tls_cipher_suites"` + ShutdownDuration time.Duration `flag:"shutdown-duration" cfg:"shutdown_duration"` } func legacyServerFlagset() *pflag.FlagSet { @@ -496,6 +497,7 @@ func legacyServerFlagset() *pflag.FlagSet { flagSet.String("tls-key-file", "", "path to private key file") flagSet.String("tls-min-version", "", "minimal TLS version for HTTPS clients (either \"TLS1.2\" or \"TLS1.3\")") flagSet.StringSlice("tls-cipher-suite", []string{}, "restricts TLS cipher suites to those listed (e.g. TLS_RSA_WITH_RC4_128_SHA) (may be given multiple times)") + flagSet.Duration("shutdown-duration", 0, "Amount of time to continue serving traffic after receiving an exit signal with readiness endpoint set to false.") return flagSet } @@ -650,6 +652,7 @@ func (l LegacyServer) convert() (Server, Server) { appServer := Server{ BindAddress: l.HTTPAddress, SecureBindAddress: l.HTTPSAddress, + ShutdownDuration: l.ShutdownDuration, } if l.TLSKeyFile != "" || l.TLSCertFile != "" { appServer.TLS = &TLS{ diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 38d48a37..b57d5aed 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -3,7 +3,6 @@ package options import ( "crypto" "net/url" - "time" ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" @@ -30,8 +29,6 @@ type Options struct { RawRedirectURL string `flag:"redirect-url" cfg:"redirect_url"` RelativeRedirectURL bool `flag:"relative-redirect-url" cfg:"relative_redirect_url"` - ShutdownDuration time.Duration `flag:"shutdown-duration" cfg:"shutdown_duration"` - AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` EmailDomains []string `flag:"email-domain" cfg:"email_domains"` WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"` @@ -164,7 +161,6 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") - flagSet.Duration("shutdown-duration", 0, "Amount of time to continue serving traffic after receiving an exit signal with readiness endpoint set to false.") flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet()) diff --git a/pkg/apis/options/server.go b/pkg/apis/options/server.go index 8fa41af8..a26d3a61 100644 --- a/pkg/apis/options/server.go +++ b/pkg/apis/options/server.go @@ -1,5 +1,7 @@ package options +import "time" + // Server represents the configuration for an HTTP(S) server type Server struct { // BindAddress is the address on which to serve traffic. @@ -13,6 +15,11 @@ type Server struct { // TLS contains the information for loading the certificate and key for the // secure traffic and further configuration for the TLS server. TLS *TLS `yaml:"tls,omitempty"` + + // Duration of time to continue serving traffic after receiving an exit signal. + // During this time the readiness endpoint will be returning HTTP 503 errors. + // Leave blank to disable. + ShutdownDuration time.Duration `yaml:"shutdownDuration,omitempty"` } // TLS contains the information for loading a TLS certificate and key diff --git a/pkg/middleware/readynesscheck.go b/pkg/middleware/readynesscheck.go index 6977cf2e..2179454a 100644 --- a/pkg/middleware/readynesscheck.go +++ b/pkg/middleware/readynesscheck.go @@ -24,6 +24,8 @@ func NewReadynessCheck(ctx context.Context, path string, verifiable Verifiable) func readynessCheck(ctx context.Context, path string, verifiable Verifiable, next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + // Check the server context (not request). + // Has the context been canceled because of SIGTERM? if ctx.Err() != nil { rw.WriteHeader(http.StatusServiceUnavailable) fmt.Fprintf(rw, "Shutting down") diff --git a/pkg/proxyhttp/server.go b/pkg/proxyhttp/server.go index 5372d1c3..6bbebd77 100644 --- a/pkg/proxyhttp/server.go +++ b/pkg/proxyhttp/server.go @@ -222,9 +222,7 @@ func (s *server) startServer(ctx context.Context, listener net.Listener) error { g.Go(func() error { <-groupCtx.Done() logger.Printf("Context canceled. Waiting %s before shutting down the listeners.", s.shutdownDuration) - time.Sleep(s.shutdownDuration) - logger.Printf("Shutting down listener.") if err := srv.Shutdown(context.Background()); err != nil {