reloadable server tls cert
This commit is contained in:
parent
5082db0bec
commit
abe2dd0d17
|
|
@ -8,7 +8,10 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
|
@ -142,11 +145,12 @@ func (s *server) setupTLSListener(opts Opts) error {
|
|||
if opts.TLS == nil {
|
||||
return errors.New("no TLS config provided")
|
||||
}
|
||||
cert, err := getCertificate(opts.TLS)
|
||||
|
||||
l, err := getCertificateLoader(opts.TLS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not load certificate: %v", err)
|
||||
}
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
config.GetCertificate = l.GetCertificate
|
||||
|
||||
if len(opts.TLS.CipherSuites) > 0 {
|
||||
cipherSuites, err := parseCipherSuites(opts.TLS.CipherSuites)
|
||||
|
|
@ -174,7 +178,11 @@ func (s *server) setupTLSListener(opts Opts) error {
|
|||
return fmt.Errorf("listen (%s) failed: %v", listenAddr, err)
|
||||
}
|
||||
|
||||
s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config)
|
||||
ka := tcpKeepAliveListener{listener.(*net.TCPListener)}
|
||||
s.tlsListener = reloadableTLSListener{
|
||||
Listener: tls.NewListener(ka, config),
|
||||
loader: l,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -194,6 +202,23 @@ func (s *server) Start(ctx context.Context) error {
|
|||
}
|
||||
|
||||
if s.tlsListener != nil {
|
||||
rl := s.tlsListener.(reloadableTLSListener)
|
||||
ch := make(chan os.Signal, 1)
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
if err := rl.Reload(); err != nil {
|
||||
logger.Errorf("Error reloading TLS certificate: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
signal.Notify(ch, syscall.SIGHUP)
|
||||
|
||||
g.Go(func() error {
|
||||
if err := s.startServer(groupCtx, s.tlsListener); err != nil {
|
||||
return fmt.Errorf("error starting secure server: %v", err)
|
||||
|
|
@ -253,24 +278,63 @@ func getListenAddress(addr string) string {
|
|||
return slice[len(slice)-1]
|
||||
}
|
||||
|
||||
// getCertificate loads the certificate data from the TLS config.
|
||||
func getCertificate(opts *options.TLS) (tls.Certificate, error) {
|
||||
keyData, err := getSecretValue(opts.Key)
|
||||
type tlsLoader struct {
|
||||
*options.TLS
|
||||
|
||||
mu sync.Mutex
|
||||
cert *tls.Certificate
|
||||
}
|
||||
|
||||
func (t *tlsLoader) LoadCert() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
keyData, err := getSecretValue(t.Key)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err)
|
||||
return fmt.Errorf("could not load key data: %v", err)
|
||||
}
|
||||
|
||||
certData, err := getSecretValue(opts.Cert)
|
||||
certData, err := getSecretValue(t.Cert)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err)
|
||||
return fmt.Errorf("could not load cert data: %v", err)
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(certData, keyData)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err)
|
||||
return fmt.Errorf("could not parse certificate data: %v", err)
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
t.cert = &cert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *tlsLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if t.cert == nil {
|
||||
return nil, fmt.Errorf("no certificate")
|
||||
}
|
||||
|
||||
return t.cert, nil
|
||||
}
|
||||
|
||||
func getCertificateLoader(opts *options.TLS) (*tlsLoader, error) {
|
||||
l := &tlsLoader{
|
||||
TLS: opts,
|
||||
}
|
||||
|
||||
if err := l.LoadCert(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
type reloadableTLSListener struct {
|
||||
net.Listener
|
||||
loader *tlsLoader
|
||||
}
|
||||
|
||||
func (rl reloadableTLSListener) Reload() error {
|
||||
return rl.loader.LoadCert()
|
||||
}
|
||||
|
||||
// getSecretValue wraps util.GetSecretValue so that we can return an error if no
|
||||
|
|
|
|||
Loading…
Reference in New Issue