diff --git a/cmd/controller/app/BUILD.bazel b/cmd/controller/app/BUILD.bazel index f4239ab87..25c6d55a2 100644 --- a/cmd/controller/app/BUILD.bazel +++ b/cmd/controller/app/BUILD.bazel @@ -53,6 +53,7 @@ go_library( "@io_k8s_sigs_gateway_api//pkg/client/clientset/versioned/scheme:go_default_library", "@io_k8s_sigs_gateway_api//pkg/client/informers/externalversions:go_default_library", "@io_k8s_utils//clock:go_default_library", + "@org_golang_x_sync//errgroup:go_default_library", ], ) diff --git a/cmd/controller/app/controller.go b/cmd/controller/app/controller.go index 548c2439c..9d7c363b7 100644 --- a/cmd/controller/app/controller.go +++ b/cmd/controller/app/controller.go @@ -19,10 +19,12 @@ package app import ( "context" "fmt" + "net" + "net/http" "os" - "sync" "time" + "golang.org/x/sync/errgroup" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/resource" @@ -65,10 +67,11 @@ const resyncPeriod = 10 * time.Hour func Run(opts *options.ControllerOptions, stopCh <-chan struct{}) { rootCtx := cmdutil.ContextWithStopCh(context.Background(), stopCh) + g, rootCtx := errgroup.WithContext(rootCtx) rootCtx = logf.NewContext(rootCtx, nil, "controller") log := logf.FromContext(rootCtx) - ctx, kubeCfg, err := buildControllerContext(rootCtx, stopCh, opts) + ctx, kubeCfg, err := buildControllerContext(rootCtx, opts) if err != nil { log.Error(err, "error building controller context", "options", opts) os.Exit(1) @@ -77,13 +80,32 @@ func Run(opts *options.ControllerOptions, stopCh <-chan struct{}) { enabledControllers := opts.EnabledControllers() log.Info(fmt.Sprintf("enabled controllers: %s", enabledControllers.List())) - metricsServer, err := ctx.Metrics.Start(opts.MetricsListenAddress, opts.EnablePprof) + ln, err := net.Listen("tcp", opts.MetricsListenAddress) if err != nil { log.Error(err, "failed to listen on prometheus address", "address", opts.MetricsListenAddress) os.Exit(1) } + server := ctx.Metrics.NewServer(ln, 
opts.EnablePprof) + + g.Go(func() error { + <-rootCtx.Done() + // allow a timeout for graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + return err + } + return nil + }) + g.Go(func() error { + log.WithValues("address", ln.Addr()).V(logf.InfoLevel).Info("listening for connections on") + if err := server.Serve(ln); err != http.ErrServerClosed { + return err + } + return nil + }) - var wg sync.WaitGroup run := func(_ context.Context) { for n, fn := range controller.Known() { log := log.WithValues("controller", n) @@ -100,33 +122,33 @@ func Run(opts *options.ControllerOptions, stopCh <-chan struct{}) { continue } - wg.Add(1) iface, err := fn(ctx) if err != nil { log.Error(err, "error starting controller") os.Exit(1) } - go func(n string, fn controller.Interface) { - defer wg.Done() - log.V(logf.InfoLevel).Info("starting controller") + + g.Go(func() error { + log.V(logf.InfoLevel).Info("starting controller", n) workers := 5 - err := fn.Run(workers, stopCh) - - if err != nil { - log.Error(err, "error starting controller") - os.Exit(1) - } - }(n, iface) + return iface.Run(workers, rootCtx.Done()) + }) } log.V(logf.DebugLevel).Info("starting shared informer factories") - ctx.SharedInformerFactory.Start(stopCh) - ctx.KubeSharedInformerFactory.Start(stopCh) - ctx.GWShared.Start(stopCh) - wg.Wait() + // TODO: we should wait for these informers to finish + ctx.SharedInformerFactory.Start(rootCtx.Done()) + ctx.KubeSharedInformerFactory.Start(rootCtx.Done()) + ctx.GWShared.Start(rootCtx.Done()) + + err := g.Wait() + if err != nil { + log.Error(err, "error starting controller") + os.Exit(1) + } log.V(logf.InfoLevel).Info("control loops exited") - ctx.Metrics.Shutdown(metricsServer) + os.Exit(0) } @@ -145,7 +167,7 @@ func Run(opts *options.ControllerOptions, stopCh <-chan struct{}) { startLeaderElection(rootCtx, opts, leaderElectionClient, ctx.Recorder, run) } -func 
buildControllerContext(ctx context.Context, stopCh <-chan struct{}, opts *options.ControllerOptions) (*controller.Context, *rest.Config, error) { +func buildControllerContext(ctx context.Context, opts *options.ControllerOptions) (*controller.Context, *rest.Config, error) { log := logf.FromContext(ctx, "build-context") // Load the users Kubernetes config kubeCfg, err := clientcmd.BuildConfigFromFlags(opts.APIServerHost, opts.Kubeconfig) @@ -238,7 +260,7 @@ func buildControllerContext(ctx context.Context, stopCh <-chan struct{}, opts *o return &controller.Context{ RootContext: ctx, - StopCh: stopCh, + StopCh: ctx.Done(), RESTConfig: kubeCfg, Client: cl, CMClient: intcl, diff --git a/pkg/controller/controller.go b/pkg/controller/controller.go index 1ea405440..d52ab705e 100644 --- a/pkg/controller/controller.go +++ b/pkg/controller/controller.go @@ -107,10 +107,10 @@ func (c *controller) Run(workers int, stopCh <-chan struct{}) error { var wg sync.WaitGroup for i := 0; i < workers; i++ { - // TODO (@munnerz): make time.Second duration configurable wg.Add(1) go func() { defer wg.Done() + // TODO (@munnerz): make time.Second duration configurable wait.Until(func() { c.worker(ctx) }, time.Second, stopCh) }() } diff --git a/pkg/webhook/authority/authority.go b/pkg/webhook/authority/authority.go index 36ddadae0..edd954468 100644 --- a/pkg/webhook/authority/authority.go +++ b/pkg/webhook/authority/authority.go @@ -210,6 +210,7 @@ func (d *DynamicAuthority) WatchRotation(stopCh <-chan struct{}) <-chan struct{} ch := make(chan struct{}, 1) d.watches = append(d.watches, ch) go func() { + defer close(ch) <-stopCh d.watchMutex.Lock() defer d.watchMutex.Unlock() diff --git a/pkg/webhook/server/BUILD.bazel b/pkg/webhook/server/BUILD.bazel index 8467ed363..6279c17cb 100644 --- a/pkg/webhook/server/BUILD.bazel +++ b/pkg/webhook/server/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/logs:go_default_library", + 
"//pkg/util:go_default_library", "//pkg/util/profiling:go_default_library", "//pkg/webhook/handlers:go_default_library", "//pkg/webhook/server/tls:go_default_library", @@ -24,6 +25,7 @@ go_library( "@io_k8s_apimachinery//pkg/util/runtime:go_default_library", "@io_k8s_component_base//cli/flag:go_default_library", "@io_k8s_sigs_controller_runtime//pkg/log:go_default_library", + "@org_golang_x_sync//errgroup:go_default_library", ], ) diff --git a/pkg/webhook/server/server.go b/pkg/webhook/server/server.go index 6c46c2d3e..2d1dc7327 100644 --- a/pkg/webhook/server/server.go +++ b/pkg/webhook/server/server.go @@ -27,6 +27,7 @@ import ( "time" "github.com/go-logr/logr" + "golang.org/x/sync/errgroup" admissionv1 "k8s.io/api/admission/v1" admissionv1beta1 "k8s.io/api/admission/v1beta1" apiextensionsinstall "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/install" @@ -41,10 +42,11 @@ import ( crlog "sigs.k8s.io/controller-runtime/pkg/log" logf "github.com/jetstack/cert-manager/pkg/logs" + "github.com/jetstack/cert-manager/pkg/util" "github.com/jetstack/cert-manager/pkg/util/profiling" "github.com/jetstack/cert-manager/pkg/webhook/handlers" servertls "github.com/jetstack/cert-manager/pkg/webhook/server/tls" - "github.com/jetstack/cert-manager/pkg/webhook/server/util" + webhookutil "github.com/jetstack/cert-manager/pkg/webhook/server/util" ) var ( @@ -125,21 +127,12 @@ func (s *Server) Run(stopCh <-chan struct{}) error { s.Log = crlog.NullLogger{} } - internalStopCh := make(chan struct{}) - // only close the internalStopCh if it hasn't already been closed - shutdown := false - defer func() { - if !shutdown { - close(internalStopCh) - } - }() - - var healthzChan <-chan error - var certSourceChan <-chan error + gctx := util.ContextWithStopCh(context.Background(), stopCh) + g, gctx := errgroup.WithContext(gctx) // if a HealthzAddr is provided, start the healthz listener if s.HealthzAddr != "" { - l, err := net.Listen("tcp", s.HealthzAddr) + healthzListener, err := 
net.Listen("tcp", s.HealthzAddr) if err != nil { return err } @@ -148,20 +141,43 @@ func (s *Server) Run(stopCh <-chan struct{}) error { mux.HandleFunc("/healthz", s.handleHealthz) mux.HandleFunc("/livez", s.handleLivez) s.Log.V(logf.InfoLevel).Info("listening for insecure healthz connections", "address", s.HealthzAddr) - healthzChan = s.startServer(l, internalStopCh, mux) + server := &http.Server{ + Handler: mux, + } + g.Go(func() error { + <-gctx.Done() + // allow a timeout for graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := server.Serve(healthzListener); err != http.ErrServerClosed { + return err + } + return nil + }) } // create a listener for actual webhook requests - l, err := net.Listen("tcp", s.ListenAddr) + listerner, err := net.Listen("tcp", s.ListenAddr) if err != nil { return err } - s.listener = l // wrap the listener with TLS if a CertificateSource is provided if s.CertificateSource != nil { s.Log.V(logf.InfoLevel).Info("listening for secure connections", "address", s.ListenAddr) - certSourceChan = s.startCertificateSource(internalStopCh) + g.Go(func() error { + if err := s.CertificateSource.Run(gctx.Done()); (err != nil) && !errors.Is(err, context.Canceled) { + return err + } + return nil + }) cipherSuites, err := ciphers.TLSCipherSuites(s.CipherSuites) if err != nil { return err @@ -170,7 +186,7 @@ func (s *Server) Run(stopCh <-chan struct{}) error { if err != nil { return err } - l = tls.NewListener(l, &tls.Config{ + listerner = tls.NewListener(listerner, &tls.Config{ GetCertificate: s.CertificateSource.GetCertificate, CipherSuites: cipherSuites, MinVersion: minVersion, @@ -180,6 +196,7 @@ func (s *Server) Run(stopCh <-chan struct{}) error { s.Log.V(logf.InfoLevel).Info("listening for insecure connections", "address", s.ListenAddr) } + s.listener = listerner mux := 
http.NewServeMux() mux.HandleFunc("/validate", s.handle(s.validate)) mux.HandleFunc("/mutate", s.handle(s.mutate)) @@ -188,31 +205,28 @@ func (s *Server) Run(stopCh <-chan struct{}) error { profiling.Install(mux) s.Log.V(logf.InfoLevel).Info("registered pprof handlers") } - listenerChan := s.startServer(l, internalStopCh, mux) - - if certSourceChan == nil { - certSourceChan = blockingChan(internalStopCh) - } - if healthzChan == nil { - healthzChan = blockingChan(internalStopCh) + server := &http.Server{ + Handler: mux, } + g.Go(func() error { + <-gctx.Done() + // allow a timeout for graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - select { - case err = <-healthzChan: - case err = <-certSourceChan: - case err = <-listenerChan: - case <-stopCh: - } + if err := server.Shutdown(ctx); err != nil { + return err + } + return nil + }) + g.Go(func() error { + if err := server.Serve(s.listener); err != http.ErrServerClosed { + return err + } + return nil + }) - close(internalStopCh) - shutdown = true - - s.Log.V(logf.DebugLevel).Info("waiting for server to shutdown") - waitForAll(healthzChan, certSourceChan, listenerChan) - - s.Log.V(logf.InfoLevel).Info("server shutdown successfully") - - return err + return g.Wait() } // Port returns the port number that the webhook listener is listening on @@ -227,67 +241,6 @@ func (s *Server) Port() (int, error) { return tcpAddr.Port, nil } -func (s *Server) startServer(l net.Listener, stopCh <-chan struct{}, handle http.Handler) <-chan error { - ch := make(chan error) - go func() { - defer close(ch) - - srv := &http.Server{ - Handler: handle, - } - select { - case err := <-channelWrapper(func() error { return srv.Serve(l) }): - ch <- err - case <-stopCh: - // allow a fixed 5s for graceful shutdown - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - if err := srv.Shutdown(ctx); err != nil { - s.Log.Error(err, "failed to gracefully 
shutdown http server") - ch <- err - } - s.Log.V(logf.DebugLevel).Info("shutdown HTTP server gracefully") - } - }() - return ch -} - -func (s *Server) startCertificateSource(stopCh <-chan struct{}) <-chan error { - fn := func() error { - return s.CertificateSource.Run(stopCh) - } - return channelWrapper(fn) -} - -func waitForAll(chs ...<-chan error) error { - for _, ch := range chs { - if err := <-ch; err != nil { - return fmt.Errorf("error waiting for goroutine to exit: %w", err) - } - } - return nil -} - -func channelWrapper(fn func() error) <-chan error { - ch := make(chan error) - go func() { - defer close(ch) - ch <- fn() - }() - return ch -} - -// blockingChan returns a 'no-op' error channel. -// When stopCh is closed, the error channel will also be closed. -func blockingChan(stopCh <-chan struct{}) <-chan error { - ch := make(chan error) - go func() { - defer close(ch) - <-stopCh - }() - return ch -} - func (s *Server) scheme() *runtime.Scheme { if s.Scheme == nil { return defaultScheme @@ -305,7 +258,7 @@ func (s *Server) validate(ctx context.Context, obj runtime.Object) (runtime.Obje return nil, errors.New("request is not of type apiextensions v1 or v1beta1") } review = &admissionv1.AdmissionReview{} - util.Convert_v1beta1_AdmissionReview_To_admission_AdmissionReview(reviewv1beta1, review) + webhookutil.Convert_v1beta1_AdmissionReview_To_admission_AdmissionReview(reviewv1beta1, review) } resp := s.ValidationWebhook.Validate(ctx, review.Request) review.Response = resp @@ -317,7 +270,7 @@ func (s *Server) validate(ctx context.Context, obj runtime.Object) (runtime.Obje // reply v1beta1 reviewv1beta1 := &admissionv1beta1.AdmissionReview{} - util.Convert_admission_AdmissionReview_To_v1beta1_AdmissionReview(review, reviewv1beta1) + webhookutil.Convert_admission_AdmissionReview_To_v1beta1_AdmissionReview(review, reviewv1beta1) return reviewv1beta1, nil } @@ -331,7 +284,7 @@ func (s *Server) mutate(ctx context.Context, obj runtime.Object) (runtime.Object return 
nil, errors.New("request is not of type apiextensions v1 or v1beta1") } review = &admissionv1.AdmissionReview{} - util.Convert_v1beta1_AdmissionReview_To_admission_AdmissionReview(reviewv1beta1, review) + webhookutil.Convert_v1beta1_AdmissionReview_To_admission_AdmissionReview(reviewv1beta1, review) } resp := s.MutationWebhook.Mutate(ctx, review.Request) review.Response = resp @@ -343,7 +296,7 @@ func (s *Server) mutate(ctx context.Context, obj runtime.Object) (runtime.Object // reply v1beta1 reviewv1beta1 := &admissionv1beta1.AdmissionReview{} - util.Convert_admission_AdmissionReview_To_v1beta1_AdmissionReview(review, reviewv1beta1) + webhookutil.Convert_admission_AdmissionReview_To_v1beta1_AdmissionReview(review, reviewv1beta1) return reviewv1beta1, nil } diff --git a/pkg/webhook/server/tls/dynamic_source.go b/pkg/webhook/server/tls/dynamic_source.go index 47cdf8d1e..1b25b56ef 100644 --- a/pkg/webhook/server/tls/dynamic_source.go +++ b/pkg/webhook/server/tls/dynamic_source.go @@ -17,6 +17,7 @@ limitations under the License. 
package tls import ( + "context" "crypto" "crypto/tls" "crypto/x509" @@ -49,7 +50,6 @@ type DynamicSource struct { Log logr.Logger cachedCertificate *tls.Certificate - nextRenew time.Time lock sync.Mutex } @@ -67,6 +67,8 @@ func (f *DynamicSource) Run(stopCh <-chan struct{}) error { authorityErrChan <- f.Authority.Run(stopCh) }() + nextRenewCh := make(chan time.Time, 1) + // initially fetch a certificate from the signing CA interval := time.Second if err := wait.PollUntil(interval, func() (done bool, err error) { @@ -78,18 +80,21 @@ func (f *DynamicSource) Run(stopCh <-chan struct{}) error { return true, fmt.Errorf("failed to run certificate authority: %w", err) } if !ok { - return true, fmt.Errorf("certificate authority stopped") + return true, context.Canceled } default: // this case avoids blocking if the authority is still running } - if err := f.regenerateCertificate(); err != nil { + if err := f.regenerateCertificate(nextRenewCh); err != nil { f.Log.Error(err, "Failed to generate initial serving certificate, retrying...", "interval", interval) return false, nil } return true, nil }, stopCh); err != nil { + // In case of an error, the stopCh is closed; wait for authorityErrChan to be closed too + <-authorityErrChan + return err } @@ -99,24 +104,35 @@ func (f *DynamicSource) Run(stopCh <-chan struct{}) error { ch := make(chan struct{}) go func() { defer close(ch) + + var renewMoment time.Time + select { + case renewMoment = <-nextRenewCh: + // We received a renew moment + default: + // This should never happen + panic("Unreacheable") + } + for { - // exit if stopCh closes + timer := time.NewTimer(renewMoment.Sub(time.Now())) + defer timer.Stop() + select { case <-stopCh: return - default: - } - // regenerate the certificate if we have gone past the 'nextRenew' time - if time.Now().After(f.nextRenew) { + case <-timer.C: ch <- struct{}{} + case renewMoment = <-nextRenewCh: + // We received a renew moment, next loop iteration will update the timer } -
time.Sleep(time.Second * 5) } }() return ch }() + // check the current certificate every 10s in case it needs updating - return wait.PollImmediateUntil(time.Second*10, func() (done bool, err error) { + if err := wait.PollImmediateUntil(time.Second*10, func() (done bool, err error) { // regenerate the serving certificate if the root CA has been rotated select { // if the authority has stopped for whatever reason, exit and return the error @@ -125,15 +141,15 @@ func (f *DynamicSource) Run(stopCh <-chan struct{}) error { return true, fmt.Errorf("failed to run certificate authority: %w", err) } if !ok { - return true, fmt.Errorf("certificate authority stopped") + return true, context.Canceled } // trigger regeneration if the root CA has been rotated case _, ok := <-rotationChan: if !ok { - return true, fmt.Errorf("channel closed") + return true, context.Canceled } f.Log.V(logf.InfoLevel).Info("Detected root CA rotation - regenerating serving certificates") - if err := f.regenerateCertificate(); err != nil { + if err := f.regenerateCertificate(nextRenewCh); err != nil { f.Log.Error(err, "Failed to regenerate serving certificate") // Return an error here and stop the source running - this case should never // occur, and if it does, indicates some form of internal error. @@ -142,15 +158,26 @@ func (f *DynamicSource) Run(stopCh <-chan struct{}) error { // trigger regeneration if a renewal is required case <-renewalChan: f.Log.V(logf.InfoLevel).Info("Serving certificate requires renewal, regenerating") - if err := f.regenerateCertificate(); err != nil { + if err := f.regenerateCertificate(nextRenewCh); err != nil { f.Log.Error(err, "Failed to regenerate serving certificate") // Return an error here and stop the source running - this case should never // occur, and if it does, indicates some form of internal error. 
return false, err } + case <-stopCh: + return true, context.Canceled } return false, nil - }, stopCh) + }, stopCh); err != nil { + // In case of an error, the stopCh is closed; wait for all channels to close + <-authorityErrChan + <-rotationChan + <-renewalChan + + return err + } + + return nil } func (f *DynamicSource) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { @@ -168,7 +195,7 @@ func (f *DynamicSource) Healthy() bool { // regenerateCertificate will trigger the cached certificate and private key to // be regenerated by requesting a new certificate from the authority. -func (f *DynamicSource) regenerateCertificate() error { +func (f *DynamicSource) regenerateCertificate(nextRenew chan<- time.Time) error { f.Log.V(logf.DebugLevel).Info("Generating new ECDSA private key") pk, err := pki.GenerateECPrivateKey(384) if err != nil { @@ -192,15 +219,13 @@ func (f *DynamicSource) regenerateCertificate() error { f.Log.V(logf.DebugLevel).Info("Signed new serving certificate") - if err := f.updateCertificate(pk, cert); err != nil { + if err := f.updateCertificate(pk, cert, nextRenew); err != nil { return err } - - f.Log.V(logf.InfoLevel).Info("Updated serving TLS certificate") return nil } -func (f *DynamicSource) updateCertificate(pk crypto.Signer, cert *x509.Certificate) error { +func (f *DynamicSource) updateCertificate(pk crypto.Signer, cert *x509.Certificate, nextRenew chan<- time.Time) error { f.lock.Lock() defer f.lock.Unlock() @@ -222,6 +247,8 @@ func (f *DynamicSource) updateCertificate(pk crypto.Signer, cert *x509.Certifica f.cachedCertificate = &bundle certDuration := cert.NotAfter.Sub(cert.NotBefore) // renew the certificate 1/3 of the time before its expiry - f.nextRenew = cert.NotAfter.Add(certDuration / -3) + nextRenew <- cert.NotAfter.Add(certDuration / -3) + + f.Log.V(logf.InfoLevel).Info("Updated serving TLS certificate") return nil } diff --git a/test/integration/certificates/metrics_controller_test.go 
b/test/integration/certificates/metrics_controller_test.go index 2e2dc1328..6f0c748c0 100644 --- a/test/integration/certificates/metrics_controller_test.go +++ b/test/integration/certificates/metrics_controller_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io/ioutil" + "net" "net/http" "strings" "testing" @@ -64,14 +65,20 @@ func TestMetricsController(t *testing.T) { if err != nil { t.Fatal(err) } - server := ctx.Metrics.NewServer(ln, false) + server := metricsHandler.NewServer(ln, false) + + go func() { + if err := server.Serve(ln); err != http.ErrServerClosed { + t.Fatal(err) + } + }() defer func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { - return err + t.Fatal(err) } }()