diff --git a/main.go b/main.go index cf7e964c..4bef9ad4 100644 --- a/main.go +++ b/main.go @@ -54,6 +54,7 @@ func main() { } validator := NewValidator(opts.EmailDomains, opts.AuthenticatedEmailsFile) + oauthproxy, err := NewOAuthProxy(opts, validator) if err != nil { logger.Fatalf("ERROR: Failed to initialise OAuth2 Proxy: %v", err) diff --git a/oauthproxy.go b/oauthproxy.go index b85c89b4..d8ec273e 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -113,6 +113,9 @@ type OAuthProxy struct { redirectValidator redirect.Validator appDirector redirect.AppDirector + cancelFunc context.CancelFunc + cancelCtx context.Context + encodeState bool } @@ -200,7 +203,9 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, err } - preAuthChain, err := buildPreAuthChain(opts, sessionStore) + cancelCtx, cancelFunc := context.WithCancel(context.Background()) + + preAuthChain, err := buildPreAuthChain(opts, sessionStore, cancelCtx) if err != nil { return nil, fmt.Errorf("could not build pre-auth chain: %v", err) } @@ -248,6 +253,8 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr redirectValidator: redirectValidator, appDirector: appDirector, encodeState: opts.EncodeState, + cancelFunc: cancelFunc, + cancelCtx: cancelCtx, } p.buildServeMux(opts.ProxyPrefix) @@ -265,17 +272,15 @@ func (p *OAuthProxy) Start() error { panic("server has not been initialised") } - ctx, cancel := context.WithCancel(context.Background()) - // Observe signals in background goroutine. go func() { sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt, syscall.SIGTERM) <-sigint - cancel() // cancel the context + p.cancelFunc() // cancel the context }() - return p.server.Start(ctx) + return p.server.Start(p.cancelCtx) } func (p *OAuthProxy) setupServer(opts *options.Options) error { @@ -284,6 +289,7 @@ func (p *OAuthProxy) setupServer(opts *options.Options) error { BindAddress: opts.Server.BindAddress, SecureBindAddress: opts.Server.SecureBindAddress, TLS: opts.Server.TLS, + ShutdownDuration: opts.ShutdownDuration, } // Option: AllowQuerySemicolons @@ -353,7 +359,7 @@ func (p *OAuthProxy) buildProxySubrouter(s *mux.Router) { // buildPreAuthChain constructs a chain that should process every request before // the OAuth2 Proxy authentication logic kicks in. // For example forcing HTTPS or health checks. -func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore) (alice.Chain, error) { +func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionStore, ctx context.Context) (alice.Chain, error) { chain := alice.New(middleware.NewScope(opts.ReverseProxy, opts.Logging.RequestIDHeader)) if opts.ForceHTTPS { @@ -374,17 +380,18 @@ func buildPreAuthChain(opts *options.Options, sessionStore sessionsapi.SessionSt // To silence logging of health checks, register the health check handler before // the logging handler + readinessCheck := middleware.NewReadynessCheck(opts.ReadyPath, sessionStore, ctx) if opts.Logging.SilencePing { chain = chain.Append( middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), - middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), + readinessCheck, middleware.NewRequestLogger(), ) } else { chain = chain.Append( middleware.NewRequestLogger(), middleware.NewHealthCheck(healthCheckPaths, healthCheckUserAgents), - middleware.NewReadynessCheck(opts.ReadyPath, sessionStore), + readinessCheck, ) } diff --git a/pkg/apis/options/options.go b/pkg/apis/options/options.go index 8fa72c7c..eef86f7e 100644 --- a/pkg/apis/options/options.go +++ b/pkg/apis/options/options.go @@ -3,6 +3,7 @@ package options import ( "crypto" "net/url" + "time" ipapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/ip" internaloidc "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/providers/oidc" @@ -29,6 +30,8 @@ type Options struct { RawRedirectURL string `flag:"redirect-url" cfg:"redirect_url"` RelativeRedirectURL bool `flag:"relative-redirect-url" cfg:"relative_redirect_url"` + ShutdownDuration time.Duration `flag:"shutdown-duration" cfg:"shutdown_duration"` + AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` EmailDomains []string `flag:"email-domain" cfg:"email_domains"` WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains"` @@ -161,6 +164,7 @@ func NewFlagSet() *pflag.FlagSet { flagSet.Int("redis-connection-idle-timeout", 0, "Redis connection idle timeout seconds, if Redis timeout option is non-zero, the --redis-connection-idle-timeout must be less then Redis timeout option") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.Bool("gcp-healthchecks", false, "Enable GCP/GKE healthcheck endpoints") + flagSet.Duration("shutdown-duration", 0, "Amount of time to continue serving traffic after receiving an exit signal with readiness endpoint set to false.") flagSet.AddFlagSet(cookieFlagSet()) flagSet.AddFlagSet(loggingFlagSet()) diff --git a/pkg/http/server.go b/pkg/http/server.go index fe76427a..25b1b0df 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -40,12 +40,16 @@ type Opts struct { // Let testing infrastructure circumvent parsing file descriptors fdFiles []*os.File + + // Graceful shutdown duration + ShutdownDuration time.Duration } // NewServer creates a new Server from the options given. func NewServer(opts Opts) (Server, error) { s := &server{ - handler: opts.Handler, + handler: opts.Handler, + shutdownDuration: opts.ShutdownDuration, } if len(opts.fdFiles) > 0 { @@ -71,6 +75,9 @@ type server struct { // ensure activation.Files are called once fdFiles []*os.File + + // Graceful shutdown duration + shutdownDuration time.Duration } // setupListener sets the server listener if the HTTP server is enabled. @@ -214,10 +221,16 @@ func (s *server) startServer(ctx context.Context, listener net.Listener) error { g.Go(func() error { <-groupCtx.Done() + logger.Printf("Context canceled. Waiting %s before shutting down the listeners.", s.shutdownDuration) + + time.Sleep(s.shutdownDuration) + + logger.Printf("Shutting down listener.") if err := srv.Shutdown(context.Background()); err != nil { return fmt.Errorf("error shutting down server: %v", err) } + return nil }) diff --git a/pkg/middleware/readynesscheck.go b/pkg/middleware/readynesscheck.go index aee871e1..bf3e6373 100644 --- a/pkg/middleware/readynesscheck.go +++ b/pkg/middleware/readynesscheck.go @@ -16,14 +16,19 @@ type Verifiable interface { // NewReadynessCheck returns a middleware that performs deep health checks // (verifies the connection to any underlying store) on a specific `path` -func NewReadynessCheck(path string, verifiable Verifiable) alice.Constructor { +func NewReadynessCheck(path string, verifiable Verifiable, ctx context.Context) alice.Constructor { return func(next http.Handler) http.Handler { - return readynessCheck(path, verifiable, next) + return readynessCheck(path, verifiable, next, ctx) } } -func readynessCheck(path string, verifiable Verifiable, next http.Handler) http.Handler { +func readynessCheck(path string, verifiable Verifiable, next http.Handler, ctx context.Context) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if ctx.Err() != nil { + rw.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(rw, "Shutting down") + return + } if path != "" && req.URL.EscapedPath() == path { if err := verifiable.VerifyConnection(req.Context()); err != nil { rw.WriteHeader(http.StatusInternalServerError)