reloadable server tls cert

This commit is contained in:
Michael Katzenellenbogen 2025-02-09 12:24:04 -05:00
parent 5082db0bec
commit abe2dd0d17
1 changed files with 75 additions and 11 deletions

View File

@ -8,7 +8,10 @@ import (
"net"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/sync/errgroup"
@ -142,11 +145,12 @@ func (s *server) setupTLSListener(opts Opts) error {
if opts.TLS == nil {
return errors.New("no TLS config provided")
}
cert, err := getCertificate(opts.TLS)
l, err := getCertificateLoader(opts.TLS)
if err != nil {
return fmt.Errorf("could not load certificate: %v", err)
}
config.Certificates = []tls.Certificate{cert}
config.GetCertificate = l.GetCertificate
if len(opts.TLS.CipherSuites) > 0 {
cipherSuites, err := parseCipherSuites(opts.TLS.CipherSuites)
@ -174,7 +178,11 @@ func (s *server) setupTLSListener(opts Opts) error {
return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
}
s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
ka := tcpKeepAliveListener{listener.(*net.TCPListener)}
s.tlsListener = reloadableTLSListener{
Listener: tls.NewListener(ka, config),
loader: l,
}
return nil
}
@ -194,6 +202,23 @@ func (s *server) Start(ctx context.Context) error {
}
if s.tlsListener != nil {
rl := s.tlsListener.(reloadableTLSListener)
ch := make(chan os.Signal, 1)
g.Go(func() error {
for {
select {
case <-ch:
if err := rl.Reload(); err != nil {
logger.Errorf("Error reloading TLS certificate: %v", err)
}
case <-ctx.Done():
return nil
}
}
})
signal.Notify(ch, syscall.SIGHUP)
g.Go(func() error {
if err := s.startServer(groupCtx, s.tlsListener); err != nil {
return fmt.Errorf("error starting secure server: %v", err)
@ -253,24 +278,63 @@ func getListenAddress(addr string) string {
return slice[len(slice)-1]
}
// getCertificate loads the certificate data from the TLS config.
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
keyData, err := getSecretValue(opts.Key)
type tlsLoader struct {
*options.TLS
mu sync.Mutex
cert *tls.Certificate
}
func (t *tlsLoader) LoadCert() error {
t.mu.Lock()
defer t.mu.Unlock()
keyData, err := getSecretValue(t.Key)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
return fmt.Errorf("could not load key data: %v", err)
}
certData, err := getSecretValue(opts.Cert)
certData, err := getSecretValue(t.Cert)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
return fmt.Errorf("could not load cert data: %v", err)
}
cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
return fmt.Errorf("could not parse certificate data: %v", err)
}
return cert, nil
t.cert = &cert
return nil
}
func (t *tlsLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if t.cert == nil {
return nil, fmt.Errorf("no certificate")
}
return t.cert, nil
}
func getCertificateLoader(opts *options.TLS) (*tlsLoader, error) {
l := &tlsLoader{
TLS: opts,
}
if err := l.LoadCert(); err != nil {
return nil, err
}
return l, nil
}
type reloadableTLSListener struct {
net.Listener
loader *tlsLoader
}
func (rl reloadableTLSListener) Reload() error {
return rl.loader.LoadCert()
}
// getSecretValue wraps util.GetSecretValue so that we can return an error if no