diff --git a/pkg/http/server.go b/pkg/http/server.go index fe76427a..df066f74 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -8,7 +8,10 @@ import ( "net" "net/http" "os" + "os/signal" "strings" + "sync" + "syscall" "time" "golang.org/x/sync/errgroup" @@ -142,11 +145,12 @@ func (s *server) setupTLSListener(opts Opts) error { if opts.TLS == nil { return errors.New("no TLS config provided") } - cert, err := getCertificate(opts.TLS) + + l, err := getCertificateLoader(opts.TLS) if err != nil { return fmt.Errorf("could not load certificate: %v", err) } - config.Certificates = []tls.Certificate{cert} + config.GetCertificate = l.GetCertificate if len(opts.TLS.CipherSuites) > 0 { cipherSuites, err := parseCipherSuites(opts.TLS.CipherSuites) @@ -174,7 +178,11 @@ func (s *server) setupTLSListener(opts Opts) error { return fmt.Errorf("listen (%s) failed: %v", listenAddr, err) } - s.tlsListener = tls.NewListener(tcpKeepAliveListener{listener.(*net.TCPListener)}, config) + ka := tcpKeepAliveListener{listener.(*net.TCPListener)} + s.tlsListener = reloadableTLSListener{ + Listener: tls.NewListener(ka, config), + loader: l, + } return nil } @@ -194,6 +202,23 @@ func (s *server) Start(ctx context.Context) error { } if s.tlsListener != nil { + rl := s.tlsListener.(reloadableTLSListener) + ch := make(chan os.Signal, 1) + + g.Go(func() error { + for { + select { + case <-ch: + if err := rl.Reload(); err != nil { + logger.Errorf("Error reloading TLS certificate: %v", err) + } + case <-ctx.Done(): + return nil + } + } + }) + signal.Notify(ch, syscall.SIGHUP) + g.Go(func() error { if err := s.startServer(groupCtx, s.tlsListener); err != nil { return fmt.Errorf("error starting secure server: %v", err) @@ -253,24 +278,63 @@ func getListenAddress(addr string) string { return slice[len(slice)-1] } -// getCertificate loads the certificate data from the TLS config. -func getCertificate(opts *options.TLS) (tls.Certificate, error) { - keyData, err := getSecretValue(opts.Key) +type tlsLoader struct { + *options.TLS + + mu sync.Mutex + cert *tls.Certificate +} + +func (t *tlsLoader) LoadCert() error { + t.mu.Lock() + defer t.mu.Unlock() + + keyData, err := getSecretValue(t.Key) if err != nil { - return tls.Certificate{}, fmt.Errorf("could not load key data: %v", err) + return fmt.Errorf("could not load key data: %v", err) } - certData, err := getSecretValue(opts.Cert) + certData, err := getSecretValue(t.Cert) if err != nil { - return tls.Certificate{}, fmt.Errorf("could not load cert data: %v", err) + return fmt.Errorf("could not load cert data: %v", err) } cert, err := tls.X509KeyPair(certData, keyData) if err != nil { - return tls.Certificate{}, fmt.Errorf("could not parse certificate data: %v", err) + return fmt.Errorf("could not parse certificate data: %v", err) } - return cert, nil + t.cert = &cert + return nil +} + +func (t *tlsLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + if t.cert == nil { + return nil, fmt.Errorf("no certificate") + } + + return t.cert, nil +} + +func getCertificateLoader(opts *options.TLS) (*tlsLoader, error) { + l := &tlsLoader{ + TLS: opts, + } + + if err := l.LoadCert(); err != nil { + return nil, err + } + + return l, nil +} + +type reloadableTLSListener struct { + net.Listener + loader *tlsLoader +} + +func (rl reloadableTLSListener) Reload() error { + return rl.loader.LoadCert() } // getSecretValue wraps util.GetSecretValue so that we can return an error if no