From 2febf94e0dcd5f02c8d5e89c2a93d851520b316f Mon Sep 17 00:00:00 2001 From: frhack Date: Thu, 12 Feb 2026 19:07:45 +0100 Subject: [PATCH] feat: add built-in health check command for Docker HEALTHCHECK support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a `health` subcommand and `--healthcheck` flag that performs an HTTP GET against the running oauth2-proxy's /ping endpoint, returning exit 0 on HTTP 200 and exit 1 otherwise. This eliminates the need to add curl or wget to distroless container images for Docker health checks. Two invocation modes: - `oauth2-proxy health` (subcommand with own flags) - `oauth2-proxy --healthcheck` (reads proxy config for address/ping-path) The health subcommand supports: - --http-address / --https-address to target the correct listener - --ping-path for custom ping endpoint paths - --timeout for configurable request timeout (default 5s) - --insecure-skip-verify for self-signed TLS certificates - Automatic wildcard-to-loopback address translation (0.0.0.0 -> 127.0.0.1) - HTTP-first with HTTPS fallback Also adds HEALTHCHECK instruction to the Dockerfile. Zero external dependencies - uses only net/http from the Go standard library. Closes #2555 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- Dockerfile | 3 + main.go | 74 ++++++ pkg/healthcheck/healthcheck.go | 156 ++++++++++++ pkg/healthcheck/healthcheck_suite_test.go | 13 + pkg/healthcheck/healthcheck_test.go | 286 ++++++++++++++++++++++ 5 files changed, 532 insertions(+) create mode 100644 pkg/healthcheck/healthcheck.go create mode 100644 pkg/healthcheck/healthcheck_suite_test.go create mode 100644 pkg/healthcheck/healthcheck_test.go diff --git a/Dockerfile b/Dockerfile index 70c744af..5635b31a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -71,4 +71,7 @@ LABEL org.opencontainers.image.licenses=MIT \ org.opencontainers.image.title=oauth2-proxy \ org.opencontainers.image.version=${VERSION} +HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \ + CMD ["/bin/oauth2-proxy", "health"] + ENTRYPOINT ["/bin/oauth2-proxy"] diff --git a/main.go b/main.go index 42e8bab0..381ea8b8 100644 --- a/main.go +++ b/main.go @@ -4,8 +4,10 @@ import ( "fmt" "os" "runtime" + "time" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/healthcheck" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/validation" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/version" @@ -16,6 +18,12 @@ import ( func main() { logger.SetFlags(logger.Lshortfile) + // Check if "health" subcommand is being invoked (e.g., "oauth2-proxy health") + if len(os.Args) > 1 && os.Args[1] == "health" { + runHealthCheck(os.Args[2:]) + return + } + configFlagSet := pflag.NewFlagSet("oauth2-proxy", pflag.ContinueOnError) // Because we parse early to determine alpha vs legacy config, we have to @@ -26,6 +34,7 @@ func main() { alphaConfig := configFlagSet.String("alpha-config", "", "path to alpha config file (use at your own risk - the structure in this config file may change between minor releases)") convertConfig := configFlagSet.Bool("convert-config-to-alpha", false, "if true, the proxy will load configuration as normal and convert existing configuration to the alpha config structure, and print it to stdout") showVersion := configFlagSet.Bool("version", false, "print version string") + checkHealth := configFlagSet.Bool("healthcheck", false, "perform a health check against a running oauth2-proxy instance and exit") configFlagSet.Parse(os.Args[1:]) if *showVersion { @@ -33,6 +42,11 @@ func main() { return } + if *checkHealth { + runHealthCheckFromConfig(*config, *alphaConfig, configFlagSet, os.Args[1:]) + return + } + if *convertConfig && *alphaConfig != "" { logger.Fatal("cannot use alpha-config and convert-config-to-alpha together") } @@ -64,6 +78,66 @@ func main() { } } +// runHealthCheck handles the "health" subcommand with its own flag set. +func runHealthCheck(args []string) { + fs := pflag.NewFlagSet("health", pflag.ContinueOnError) + httpAddr := fs.String("http-address", healthcheck.DefaultHTTPAddress, "HTTP address of the oauth2-proxy instance to check") + httpsAddr := fs.String("https-address", "", "HTTPS address of the oauth2-proxy instance to check") + pingPath := fs.String("ping-path", healthcheck.DefaultPingPath, "path of the ping endpoint") + timeout := fs.Duration("timeout", healthcheck.DefaultTimeout, "timeout for the health check request") + insecure := fs.Bool("insecure-skip-verify", false, "skip TLS certificate verification for HTTPS health checks") + + if err := fs.Parse(args); err != nil { + logger.Fatalf("ERROR: %v", err) + } + + opts := healthcheck.CheckOptions{ + HTTPAddress: *httpAddr, + HTTPSAddress: *httpsAddr, + PingPath: *pingPath, + Timeout: *timeout, + InsecureSkipVerify: *insecure, + } + + if err := healthcheck.Run(opts); err != nil { + fmt.Fprintf(os.Stderr, "healthcheck failed: %v\n", err) + os.Exit(1) + } + + fmt.Println("OK") +} + +// runHealthCheckFromConfig performs a health check using the loaded configuration. +// This supports the --healthcheck flag which respects the same configuration as the proxy. +func runHealthCheckFromConfig(config, alphaConfig string, extraFlags *pflag.FlagSet, args []string) { + opts, err := loadConfiguration(config, alphaConfig, extraFlags, args) + if err != nil { + // If config loading fails, fall back to defaults + logger.Printf("WARNING: failed to load configuration: %v; using defaults", err) + checkOpts := healthcheck.DefaultCheckOptions() + if err := healthcheck.Run(checkOpts); err != nil { + fmt.Fprintf(os.Stderr, "healthcheck failed: %v\n", err) + os.Exit(1) + } + fmt.Println("OK") + return + } + + checkOpts := healthcheck.CheckOptions{ + HTTPAddress: opts.Server.BindAddress, + HTTPSAddress: opts.Server.SecureBindAddress, + PingPath: opts.PingPath, + Timeout: 5 * time.Second, + } + + if err := healthcheck.Run(checkOpts); err != nil { + fmt.Fprintf(os.Stderr, "healthcheck failed: %v\n", err) + os.Exit(1) + } + + fmt.Println("OK") +} + // loadConfiguration will load in the user's configuration. // It will either load the alpha configuration (if alphaConfig is given) // or the legacy configuration. diff --git a/pkg/healthcheck/healthcheck.go b/pkg/healthcheck/healthcheck.go new file mode 100644 index 00000000..cb8d2376 --- /dev/null +++ b/pkg/healthcheck/healthcheck.go @@ -0,0 +1,156 @@ +package healthcheck + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" +) + +const ( + // DefaultHTTPAddress is the default bind address for the HTTP server. + DefaultHTTPAddress = "127.0.0.1:4180" + + // DefaultPingPath is the default path for the ping endpoint. + DefaultPingPath = "/ping" + + // DefaultTimeout is the default timeout for the health check request. + DefaultTimeout = 5 * time.Second +) + +// CheckOptions holds configuration for a health check request. +type CheckOptions struct { + // HTTPAddress is the address the oauth2-proxy HTTP server is bound to. + // Format: [http://]: + HTTPAddress string + + // HTTPSAddress is the address the oauth2-proxy HTTPS server is bound to. + // Format: : + HTTPSAddress string + + // PingPath is the URL path for the ping endpoint. + PingPath string + + // Timeout is the maximum duration for the health check request. + Timeout time.Duration + + // InsecureSkipVerify skips TLS certificate verification for HTTPS checks. + InsecureSkipVerify bool +} + +// DefaultCheckOptions returns CheckOptions with sensible defaults. +func DefaultCheckOptions() CheckOptions { + return CheckOptions{ + HTTPAddress: DefaultHTTPAddress, + HTTPSAddress: "", + PingPath: DefaultPingPath, + Timeout: DefaultTimeout, + } +} + +// Run performs the health check and returns nil on success or an error on failure. +// It checks the HTTP address first. If the HTTP address is empty or disabled, +// it falls back to the HTTPS address. +func Run(opts CheckOptions) error { + if opts.PingPath == "" { + opts.PingPath = DefaultPingPath + } + if opts.Timeout == 0 { + opts.Timeout = DefaultTimeout + } + + httpAddr := normalizeAddress(opts.HTTPAddress) + httpsAddr := normalizeAddress(opts.HTTPSAddress) + + // Try HTTP first, then HTTPS + if httpAddr != "" && httpAddr != "-" { + return checkEndpoint("http", httpAddr, opts.PingPath, opts.Timeout, opts.InsecureSkipVerify) + } + + if httpsAddr != "" && httpsAddr != "-" { + return checkEndpoint("https", httpsAddr, opts.PingPath, opts.Timeout, opts.InsecureSkipVerify) + } + + return fmt.Errorf("no bind address configured; cannot perform health check") +} + +// normalizeAddress strips an optional scheme prefix and returns the host:port. +func normalizeAddress(addr string) string { + addr = strings.TrimSpace(addr) + // Strip optional scheme prefix (e.g., "http://127.0.0.1:4180") + for _, prefix := range []string{"http://", "https://"} { + if strings.HasPrefix(strings.ToLower(addr), prefix) { + addr = addr[len(prefix):] + break + } + } + return addr +} + +// checkEndpoint performs a GET request against scheme://addr/pingPath and validates +// that the response status is 200 OK. +func checkEndpoint(scheme, addr, pingPath string, timeout time.Duration, insecureSkipVerify bool) error { + // Replace unspecified addresses with loopback so the check connects locally. + host, port, err := net.SplitHostPort(addr) + if err != nil { + return fmt.Errorf("invalid address %q: %v", addr, err) + } + + host = replaceUnspecified(host) + target := net.JoinHostPort(host, port) + + url := fmt.Sprintf("%s://%s%s", scheme, target, pingPath) + + client := &http.Client{ + Timeout: timeout, + // Do not follow redirects; we expect a direct 200 from the ping endpoint. + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + if scheme == "https" { + client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: insecureSkipVerify, //nolint:gosec // intentional for local health check against self-signed certs + }, + } + } + + resp, err := client.Get(url) //nolint:gosec // URL is constructed from known configuration, not user input + if err != nil { + return fmt.Errorf("health check failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) + return fmt.Errorf("health check returned status %d: %s", resp.StatusCode, string(body)) + } + + return nil +} + +// replaceUnspecified replaces unspecified (wildcard) addresses with their +// loopback equivalents so the health check connects locally. +func replaceUnspecified(host string) string { + switch host { + case "", "0.0.0.0": + return "127.0.0.1" + case "::", "[::]": + return "::1" + } + // Strip brackets from IPv6 addresses that net.SplitHostPort already handled + host = strings.Trim(host, "[]") + ip := net.ParseIP(host) + if ip != nil && ip.IsUnspecified() { + if ip.To4() != nil { + return "127.0.0.1" + } + return "::1" + } + return host +} diff --git a/pkg/healthcheck/healthcheck_suite_test.go b/pkg/healthcheck/healthcheck_suite_test.go new file mode 100644 index 00000000..0f31c562 --- /dev/null +++ b/pkg/healthcheck/healthcheck_suite_test.go @@ -0,0 +1,13 @@ +package healthcheck + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestHealthcheckSuite(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Healthcheck") +} diff --git a/pkg/healthcheck/healthcheck_test.go b/pkg/healthcheck/healthcheck_test.go new file mode 100644 index 00000000..21745529 --- /dev/null +++ b/pkg/healthcheck/healthcheck_test.go @@ -0,0 +1,286 @@ +package healthcheck + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Healthcheck", func() { + Describe("normalizeAddress", func() { + type normalizeInput struct { + input string + expected string + } + + DescribeTable("should strip scheme prefixes and whitespace", + func(in normalizeInput) { + Expect(normalizeAddress(in.input)).To(Equal(in.expected)) + }, + Entry("plain address", normalizeInput{ + input: "127.0.0.1:4180", expected: "127.0.0.1:4180", + }), + Entry("with http scheme", normalizeInput{ + input: "http://127.0.0.1:4180", expected: "127.0.0.1:4180", + }), + Entry("with https scheme", normalizeInput{ + input: "https://127.0.0.1:443", expected: "127.0.0.1:443", + }), + Entry("with leading whitespace", normalizeInput{ + input: " 127.0.0.1:4180", expected: "127.0.0.1:4180", + }), + Entry("empty string", normalizeInput{ + input: "", expected: "", + }), + Entry("disabled address", normalizeInput{ + input: "-", expected: "-", + }), + ) + }) + + Describe("replaceUnspecified", func() { + type replaceInput struct { + input string + expected string + } + + DescribeTable("should replace unspecified addresses with loopback", + func(in replaceInput) { + Expect(replaceUnspecified(in.input)).To(Equal(in.expected)) + }, + Entry("empty string", replaceInput{ + input: "", expected: "127.0.0.1", + }), + Entry("IPv4 unspecified", replaceInput{ + input: "0.0.0.0", expected: "127.0.0.1", + }), + Entry("IPv6 unspecified (::)", replaceInput{ + input: "::", expected: "::1", + }), + Entry("IPv6 unspecified with brackets", replaceInput{ + input: "[::]", expected: "::1", + }), + Entry("IPv4 localhost", replaceInput{ + input: "127.0.0.1", expected: "127.0.0.1", + }), + Entry("specific IPv4 address", replaceInput{ + input: "10.0.0.1", expected: "10.0.0.1", + }), + ) + }) + + Describe("Run", func() { + var ( + server *httptest.Server + listener net.Listener + ) + + AfterEach(func() { + if server != nil { + server.Close() + } + }) + + It("should succeed when ping endpoint returns 200", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ping" { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "OK") + return + } + w.WriteHeader(http.StatusNotFound) + })) + + // Extract host:port from the test server URL + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: addr, + PingPath: "/ping", + Timeout: 2 * time.Second, + } + + Expect(Run(opts)).To(Succeed()) + }) + + It("should fail when ping endpoint returns non-200", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprint(w, "not ready") + })) + + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: addr, + PingPath: "/ping", + Timeout: 2 * time.Second, + } + + err := Run(opts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("status 503")) + }) + + It("should fail when server is not reachable", func() { + // Use a random port that is unlikely to have a server + listener, _ = net.Listen("tcp", "127.0.0.1:0") + addr := listener.Addr().String() + listener.Close() // Close immediately so the port is free but nothing is listening + + opts := CheckOptions{ + HTTPAddress: addr, + PingPath: "/ping", + Timeout: 1 * time.Second, + } + + err := Run(opts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("health check failed")) + }) + + It("should fail when no address is configured", func() { + opts := CheckOptions{ + HTTPAddress: "", + HTTPSAddress: "", + PingPath: "/ping", + Timeout: 1 * time.Second, + } + + err := Run(opts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no bind address configured")) + }) + + It("should fail when address is disabled with -", func() { + opts := CheckOptions{ + HTTPAddress: "-", + HTTPSAddress: "-", + PingPath: "/ping", + Timeout: 1 * time.Second, + } + + err := Run(opts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("no bind address configured")) + }) + + It("should use default ping path when not specified", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ping" { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "OK") + return + } + w.WriteHeader(http.StatusNotFound) + })) + + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: addr, + PingPath: "", // should default to /ping + Timeout: 2 * time.Second, + } + + Expect(Run(opts)).To(Succeed()) + }) + + It("should use a custom ping path", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/healthz" { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "OK") + return + } + w.WriteHeader(http.StatusNotFound) + })) + + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: addr, + PingPath: "/healthz", + Timeout: 2 * time.Second, + } + + Expect(Run(opts)).To(Succeed()) + }) + + It("should handle address with http:// scheme prefix", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "OK") + })) + + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: "http://" + addr, + PingPath: "/ping", + Timeout: 2 * time.Second, + } + + Expect(Run(opts)).To(Succeed()) + }) + + It("should fall back to HTTPS when HTTP address is empty", func() { + server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/ping" { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "OK") + return + } + w.WriteHeader(http.StatusNotFound) + })) + + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: "", + HTTPSAddress: addr, + PingPath: "/ping", + Timeout: 2 * time.Second, + InsecureSkipVerify: true, + } + + Expect(Run(opts)).To(Succeed()) + }) + + It("should respect timeout", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate a slow server + time.Sleep(3 * time.Second) + w.WriteHeader(http.StatusOK) + })) + + addr := server.Listener.Addr().String() + + opts := CheckOptions{ + HTTPAddress: addr, + PingPath: "/ping", + Timeout: 500 * time.Millisecond, + } + + err := Run(opts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("health check failed")) + }) + }) + + Describe("DefaultCheckOptions", func() { + It("should return sensible defaults", func() { + opts := DefaultCheckOptions() + Expect(opts.HTTPAddress).To(Equal(DefaultHTTPAddress)) + Expect(opts.PingPath).To(Equal(DefaultPingPath)) + Expect(opts.Timeout).To(Equal(DefaultTimeout)) + Expect(opts.HTTPSAddress).To(BeEmpty()) + Expect(opts.InsecureSkipVerify).To(BeFalse()) + }) + }) +})