Merge befd7e8588 into 9168731c7a
				
					
				
			This commit is contained in:
		
						commit
						08fc986c0c
					
				|  | @ -28,6 +28,7 @@ | ||||||
| - [#3166](https://github.com/oauth2-proxy/oauth2-proxy/pull/3166) chore(dep): upgrade to latest golang 1.24.6 (@tuunit) | - [#3166](https://github.com/oauth2-proxy/oauth2-proxy/pull/3166) chore(dep): upgrade to latest golang 1.24.6 (@tuunit) | ||||||
| - [#3156](https://github.com/oauth2-proxy/oauth2-proxy/pull/3156) feat: allow disable-keep-alives configuration for upstream (@jet-go) | - [#3156](https://github.com/oauth2-proxy/oauth2-proxy/pull/3156) feat: allow disable-keep-alives configuration for upstream (@jet-go) | ||||||
| - [#3150](https://github.com/oauth2-proxy/oauth2-proxy/pull/3150) fix: Gitea team membership (@MagicRB, @tuunit) | - [#3150](https://github.com/oauth2-proxy/oauth2-proxy/pull/3150) fix: Gitea team membership (@MagicRB, @tuunit) | ||||||
|  | - [#2953](https://github.com/oauth2-proxy/oauth2-proxy/pull/2953) feat: reloadable server TLS certificate (@emsixteeen) | ||||||
| 
 | 
 | ||||||
| # V7.11.0 | # V7.11.0 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -36,6 +36,9 @@ There are two recommended configurations: | ||||||
|     If not specified, the defaults from [`crypto/tls`](https://pkg.go.dev/crypto/tls#CipherSuites) of the currently used `go` version for building `oauth2-proxy` will be used. |     If not specified, the defaults from [`crypto/tls`](https://pkg.go.dev/crypto/tls#CipherSuites) of the currently used `go` version for building `oauth2-proxy` will be used. | ||||||
|     A complete list of valid TLS cipher suite names can be found in [`crypto/tls`](https://pkg.go.dev/crypto/tls#pkg-constants). |     A complete list of valid TLS cipher suite names can be found in [`crypto/tls`](https://pkg.go.dev/crypto/tls#pkg-constants). | ||||||
| 
 | 
 | ||||||
|  | 3.  The TLS server certificate and key can be reloaded without restarting `oauth2-proxy` by sending a `SIGHUP` to a running `oauth2-proxy` process. | ||||||
|  |     If the `oauth2-proxy` server encounters a failure while reloading the certificate or key, the existing certificate and key will remain unchanged and an error will be logged.   | ||||||
|  | 
 | ||||||
| ### Terminate TLS at Reverse Proxy, e.g. Nginx | ### Terminate TLS at Reverse Proxy, e.g. Nginx | ||||||
| 
 | 
 | ||||||
| 1.  Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or ... | 1.  Configure SSL Termination with [Nginx](http://nginx.org/) (example config below), Amazon ELB, Google Cloud Platform Load Balancing, or ... | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ import ( | ||||||
| 	. "github.com/onsi/gomega" | 	. "github.com/onsi/gomega" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | var ipv4Addr, ipv6Addr = "127.0.0.1", "::1" | ||||||
| var ipv4CertData, ipv6CertData []byte | var ipv4CertData, ipv6CertData []byte | ||||||
| var ipv4CertDataSource, ipv4KeyDataSource options.SecretSource | var ipv4CertDataSource, ipv4KeyDataSource options.SecretSource | ||||||
| var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource | var ipv6CertDataSource, ipv6KeyDataSource options.SecretSource | ||||||
|  | @ -40,49 +41,72 @@ func httpGet(ctx context.Context, url string) (*http.Response, error) { | ||||||
| 	return c.Do(req) | 	return c.Do(req) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func generateCert(ipaddr string) (certData, certOutBytes, keyOutBytes []byte, err error) { | ||||||
|  | 	certBytes, keyBytes, err := util.GenerateCert(ipaddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	certData = certBytes | ||||||
|  | 
 | ||||||
|  | 	certOut := new(bytes.Buffer) | ||||||
|  | 	if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	certOutBytes = certOut.Bytes() | ||||||
|  | 
 | ||||||
|  | 	keyOut := new(bytes.Buffer) | ||||||
|  | 	if err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}); err != nil { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	keyOutBytes = keyOut.Bytes() | ||||||
|  | 
 | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func generateX509Cert(certSource, keySource options.SecretSource) (*x509.Certificate, error) { | ||||||
|  | 	cert, err := tls.X509KeyPair(certSource.Value, keySource.Value) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	certificate, err := x509.ParseCertificate(cert.Certificate[0]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return certificate, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func addCertToTransportRootCAs(transport *http.Transport, cert ...*x509.Certificate) { | ||||||
|  | 	transport.TLSClientConfig.RootCAs = x509.NewCertPool() | ||||||
|  | 	for _, c := range cert { | ||||||
|  | 		transport.TLSClientConfig.RootCAs.AddCert(c) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| var _ = BeforeSuite(func() { | var _ = BeforeSuite(func() { | ||||||
| 	By("Generating a ipv4 self-signed cert for TLS tests", func() { | 	By("Generating a ipv4 self-signed cert for TLS tests", func() { | ||||||
| 		certBytes, keyBytes, err := util.GenerateCert("127.0.0.1") | 		ipv4Cert, ipv4CertBytes, ipv4KeyBytes, err := generateCert(ipv4Addr) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 		ipv4CertData = certBytes |  | ||||||
| 
 | 
 | ||||||
| 		certOut := new(bytes.Buffer) | 		ipv4CertData, ipv4CertDataSource.Value, ipv4KeyDataSource.Value = ipv4Cert, ipv4CertBytes, ipv4KeyBytes | ||||||
| 		Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) |  | ||||||
| 		ipv4CertDataSource.Value = certOut.Bytes() |  | ||||||
| 		keyOut := new(bytes.Buffer) |  | ||||||
| 		Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})).To(Succeed()) |  | ||||||
| 		ipv4KeyDataSource.Value = keyOut.Bytes() |  | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	By("Generating a ipv6 self-signed cert for TLS tests", func() { | 	By("Generating a ipv6 self-signed cert for TLS tests", func() { | ||||||
| 		certBytes, keyBytes, err := util.GenerateCert("::1") | 		ipv6Cert, ipv6CertBytes, ipv6KeyBytes, err := generateCert(ipv6Addr) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 		ipv6CertData = certBytes |  | ||||||
| 
 | 
 | ||||||
| 		certOut := new(bytes.Buffer) | 		ipv6CertData, ipv6CertDataSource.Value, ipv6KeyDataSource.Value = ipv6Cert, ipv6CertBytes, ipv6KeyBytes | ||||||
| 		Expect(pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certBytes})).To(Succeed()) |  | ||||||
| 		ipv6CertDataSource.Value = certOut.Bytes() |  | ||||||
| 		keyOut := new(bytes.Buffer) |  | ||||||
| 		Expect(pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes})).To(Succeed()) |  | ||||||
| 		ipv6KeyDataSource.Value = keyOut.Bytes() |  | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	By("Setting up a http client", func() { | 	By("Setting up a http client", func() { | ||||||
| 		ipv4cert, err := tls.X509KeyPair(ipv4CertDataSource.Value, ipv4KeyDataSource.Value) | 		ipv4certificate, err := generateX509Cert(ipv4CertDataSource, ipv4KeyDataSource) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) |  | ||||||
| 		ipv6cert, err := tls.X509KeyPair(ipv6CertDataSource.Value, ipv6KeyDataSource.Value) |  | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 
 | 
 | ||||||
| 		ipv4certificate, err := x509.ParseCertificate(ipv4cert.Certificate[0]) | 		ipv6certificate, err := generateX509Cert(ipv6CertDataSource, ipv6KeyDataSource) | ||||||
| 		Expect(err).ToNot(HaveOccurred()) | 		Expect(err).ToNot(HaveOccurred()) | ||||||
| 		ipv6certificate, err := x509.ParseCertificate(ipv6cert.Certificate[0]) |  | ||||||
| 		Expect(err).ToNot(HaveOccurred()) |  | ||||||
| 
 |  | ||||||
| 		certpool := x509.NewCertPool() |  | ||||||
| 		certpool.AddCert(ipv4certificate) |  | ||||||
| 		certpool.AddCert(ipv6certificate) |  | ||||||
| 
 | 
 | ||||||
| 		transport = http.DefaultTransport.(*http.Transport).Clone() | 		transport = http.DefaultTransport.(*http.Transport).Clone() | ||||||
| 		transport.TLSClientConfig.RootCAs = certpool | 		addCertToTransportRootCAs(transport, ipv4certificate, ipv6certificate) | ||||||
| 	}) | 	}) | ||||||
| }) | }) | ||||||
|  |  | ||||||
|  | @ -8,7 +8,10 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"os/signal" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  | 	"syscall" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"golang.org/x/sync/errgroup" | 	"golang.org/x/sync/errgroup" | ||||||
|  | @ -142,11 +145,12 @@ func (s *server) setupTLSListener(opts Opts) error { | ||||||
| 	if opts.TLS == nil { | 	if opts.TLS == nil { | ||||||
| 		return errors.New("no TLS config provided") | 		return errors.New("no TLS config provided") | ||||||
| 	} | 	} | ||||||
| 	cert, err := getCertificate(opts.TLS) | 
 | ||||||
|  | 	loader, err := getCertificateLoader(opts.TLS) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("could not load certificate: %v", err) | 		return fmt.Errorf("could not load certificate: %v", err) | ||||||
| 	} | 	} | ||||||
| 	config.Certificates = []tls.Certificate{cert} | 	config.GetCertificate = loader.GetCertificate | ||||||
| 
 | 
 | ||||||
| 	if len(opts.TLS.CipherSuites) > 0 { | 	if len(opts.TLS.CipherSuites) > 0 { | ||||||
| 		cipherSuites, err := parseCipherSuites(opts.TLS.CipherSuites) | 		cipherSuites, err := parseCipherSuites(opts.TLS.CipherSuites) | ||||||
|  | @ -174,7 +178,13 @@ func (s *server) setupTLSListener(opts Opts) error { | ||||||
| 		return fmt.Errorf("listen (%s) failed: %v", listenAddr, err) | 		return fmt.Errorf("listen (%s) failed: %v", listenAddr, err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config) | 	s.tlsListener = reloadableTLSListener{ | ||||||
|  | 		Listener: tls.NewListener( | ||||||
|  | 			tcpKeepAliveListener{listener.(*net.TCPListener)}, | ||||||
|  | 			config, | ||||||
|  | 		), | ||||||
|  | 		loader: loader, | ||||||
|  | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -194,6 +204,21 @@ func (s *server) Start(ctx context.Context) error { | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	if s.tlsListener != nil { | 	if s.tlsListener != nil { | ||||||
|  | 		listener := s.tlsListener.(reloadableTLSListener) | ||||||
|  | 		ch := make(chan os.Signal, 1) | ||||||
|  | 		signal.Notify(ch, syscall.SIGHUP) | ||||||
|  | 		g.Go(func() error { | ||||||
|  | 			for { | ||||||
|  | 				select { | ||||||
|  | 				case <-ch: | ||||||
|  | 					if err := listener.Reload(); err != nil { | ||||||
|  | 						logger.Errorf("Error reloading TLS certificate: %v", err) | ||||||
|  | 					} | ||||||
|  | 				case <-ctx.Done(): | ||||||
|  | 					return nil | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
| 		g.Go(func() error { | 		g.Go(func() error { | ||||||
| 			if err := s.startServer(groupCtx, s.tlsListener); err != nil { | 			if err := s.startServer(groupCtx, s.tlsListener); err != nil { | ||||||
| 				return fmt.Errorf("error starting secure server: %v", err) | 				return fmt.Errorf("error starting secure server: %v", err) | ||||||
|  | @ -253,24 +278,63 @@ func getListenAddress(addr string) string { | ||||||
| 	return slice[len(slice)-1] | 	return slice[len(slice)-1] | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getCertificate loads the certificate data from the TLS config.
 | type tlsLoader struct { | ||||||
| func getCertificate(opts *options.TLS) (tls.Certificate, error) { | 	*options.TLS | ||||||
| 	keyData, err := getSecretValue(opts.Key) | 
 | ||||||
|  | 	mu   sync.Mutex | ||||||
|  | 	cert *tls.Certificate | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *tlsLoader) LoadCert() error { | ||||||
|  | 	t.mu.Lock() | ||||||
|  | 	defer t.mu.Unlock() | ||||||
|  | 
 | ||||||
|  | 	keyData, err := getSecretValue(t.Key) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err) | 		return fmt.Errorf("could not load key data: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	certData, err := getSecretValue(opts.Cert) | 	certData, err := getSecretValue(t.Cert) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err) | 		return fmt.Errorf("could not load cert data: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	cert, err := tls.X509KeyPair(certData, keyData) | 	cert, err := tls.X509KeyPair(certData, keyData) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err) | 		return fmt.Errorf("could not parse certificate data: %v", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return cert, nil | 	t.cert = &cert | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (t *tlsLoader) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { | ||||||
|  | 	if t.cert == nil { | ||||||
|  | 		return nil, fmt.Errorf("no certificate") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return t.cert, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func getCertificateLoader(opts *options.TLS) (*tlsLoader, error) { | ||||||
|  | 	loader := &tlsLoader{ | ||||||
|  | 		TLS: opts, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := loader.LoadCert(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return loader, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type reloadableTLSListener struct { | ||||||
|  | 	net.Listener | ||||||
|  | 	loader *tlsLoader | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (rl reloadableTLSListener) Reload() error { | ||||||
|  | 	return rl.loader.LoadCert() | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // getSecretValue wraps util.GetSecretValue so that we can return an error if no
 | // getSecretValue wraps util.GetSecretValue so that we can return an error if no
 | ||||||
|  |  | ||||||
|  | @ -8,6 +8,7 @@ import ( | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
|  | 	"syscall" | ||||||
| 
 | 
 | ||||||
| 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | 	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" | ||||||
| 	. "github.com/onsi/ginkgo/v2" | 	. "github.com/onsi/ginkgo/v2" | ||||||
|  | @ -835,6 +836,38 @@ var _ = Describe("Server", func() { | ||||||
| 				Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1)) | 				Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1)) | ||||||
| 				Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(ipv4CertData)) | 				Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(ipv4CertData)) | ||||||
| 			}) | 			}) | ||||||
|  | 
 | ||||||
|  | 			It("Reloads the certificate on SIGHUP", func() { | ||||||
|  | 				go func() { | ||||||
|  | 					defer GinkgoRecover() | ||||||
|  | 					Expect(srv.Start(ctx)).To(Succeed()) | ||||||
|  | 				}() | ||||||
|  | 
 | ||||||
|  | 				var err error | ||||||
|  | 
 | ||||||
|  | 				ipv4CertData, ipv4CertDataSource.Value, ipv4KeyDataSource.Value, err = generateCert(ipv4Addr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				ipv6CertData, ipv6CertDataSource.Value, ipv6KeyDataSource.Value, err = generateCert(ipv6Addr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				ipv4Certificate, err := generateX509Cert(ipv4CertDataSource, ipv4KeyDataSource) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				ipv6Certificate, err := generateX509Cert(ipv6CertDataSource, ipv6KeyDataSource) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				addCertToTransportRootCAs(transport, ipv4Certificate, ipv6Certificate) | ||||||
|  | 
 | ||||||
|  | 				err = syscall.Kill(syscall.Getpid(), syscall.SIGHUP) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 
 | ||||||
|  | 				resp, err := httpGet(ctx, secureListenAddr) | ||||||
|  | 				Expect(err).ToNot(HaveOccurred()) | ||||||
|  | 				Expect(resp.StatusCode).To(Equal(http.StatusOK)) | ||||||
|  | 
 | ||||||
|  | 				Expect(resp.TLS.VerifiedChains).Should(HaveLen(1)) | ||||||
|  | 				Expect(resp.TLS.VerifiedChains[0]).Should(HaveLen(1)) | ||||||
|  | 				Expect(resp.TLS.VerifiedChains[0][0].Raw).Should(Equal(ipv4CertData)) | ||||||
|  | 			}) | ||||||
| 		}) | 		}) | ||||||
| 
 | 
 | ||||||
| 		Context("with a fd ipv4 http and an ipv4 https server", func() { | 		Context("with a fd ipv4 http and an ipv4 https server", func() { | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue