From 5926a53706f6894a9ef84c0255d7e2cb03a9bdb1 Mon Sep 17 00:00:00 2001 From: James Munnelly Date: Sun, 6 Aug 2017 23:49:19 +0100 Subject: [PATCH] Refactor controller loop to only perform authorizations when issuing/renewing --- pkg/controller/certificates/controller.go | 2 +- pkg/controller/certificates/sync.go | 76 ++++++++++++++++------- pkg/scheduler/scheduler.go | 17 +++-- pkg/scheduler/scheduler_test.go | 70 +++++++++++++++++++++ 4 files changed, 137 insertions(+), 28 deletions(-) create mode 100644 pkg/scheduler/scheduler_test.go diff --git a/pkg/controller/certificates/controller.go b/pkg/controller/certificates/controller.go index 3d4797d1e..ef0fb8eb4 100644 --- a/pkg/controller/certificates/controller.go +++ b/pkg/controller/certificates/controller.go @@ -50,7 +50,7 @@ type controller struct { ingressLister extlisters.IngressLister queue workqueue.RateLimitingInterface - scheduledWorkQueue *scheduler.ScheduledWorkQueue + scheduledWorkQueue scheduler.ScheduledWorkQueue } // New returns a new Certificates controller. It sets up the informer handler diff --git a/pkg/controller/certificates/sync.go b/pkg/controller/certificates/sync.go index 30a284024..577900b3c 100644 --- a/pkg/controller/certificates/sync.go +++ b/pkg/controller/certificates/sync.go @@ -40,36 +40,39 @@ func (c *controller) sync(crt *v1alpha1.Certificate) (err error) { return fmt.Errorf("error getting issuer implementation for issuer '%s': %s", issuerObj.Name, err.Error()) } - log.Printf("Preparing Issuer '%s/%s' and Certificate '%s/%s'", issuerObj.Namespace, issuerObj.Name, crt.Namespace, crt.Name) - // TODO: move this to after the certificate check to avoid unneeded authorization checks - err = i.Prepare(crt) - - if err != nil { - return err - } - - log.Printf("Finished preparing with Issuer '%s/%s' and Certificate '%s/%s'", issuerObj.Namespace, issuerObj.Name, crt.Namespace, crt.Name) - - defer c.scheduleRenewal(crt) - - // step one: check if referenced secret exists, if not, trigger issue event + // grab existing certificate and validate private key cert, _, err := c.getCertificate(crt.Namespace, crt.Spec.SecretName) - if err != nil { - if k8sErrors.IsNotFound(err) || err == errInvalidCertificateData { - return c.issue(i, crt) - } + // if an error is returned, and that error is something other than + // IsNotFound or invalid data, then we should return the error. + if err != nil && !k8sErrors.IsNotFound(err) && err != errInvalidCertificateData { return err } - // step two: check if referenced secret is valid for listed domains. if not, return failure - if !util.EqualUnsorted(crt.Spec.Domains, cert.DNSNames) { - log.Printf("list of domains on certificate do not match domains in spec") + // as there is an existing certificate, or we may create one below, we will + // run scheduleRenewal to schedule a renewal if required at the end of + // execution. + defer c.scheduleRenewal(crt) + + // if the certificate was not found, or the certificate data is invalid, we + // should issue a new certificate + if k8sErrors.IsNotFound(err) || err == errInvalidCertificateData { return c.issue(i, crt) } + + // if the certificate is valid for a list of domains other than those + // listed in the certificate spec, we should re-issue the certificate + if !util.EqualUnsorted(crt.Spec.Domains, cert.DNSNames) { + return c.issue(i, crt) + } + + // calculate the amount of time until expiry durationUntilExpiry := cert.NotAfter.Sub(time.Now()) + // calculate how long until we should start attempting to renew the + // certificate renewIn := durationUntilExpiry - renewBefore - // step three: check if referenced secret is valid (after start & before expiry) + + // if we should being attempting to renew now, then trigger a renewal if renewIn <= 0 { return c.renew(i, crt) } @@ -77,6 +80,16 @@ func (c *controller) sync(crt *v1alpha1.Certificate) (err error) { return nil } +func needsRenew(cert *x509.Certificate) bool { + durationUntilExpiry := cert.NotAfter.Sub(time.Now()) + renewIn := durationUntilExpiry - renewBefore + // step three: check if referenced secret is valid (after start & before expiry) + if renewIn <= 0 { + return true + } + return false +} + func (c *controller) getCertificate(namespace, name string) (*x509.Certificate, *rsa.PrivateKey, error) { secret, err := c.client.CoreV1().Secrets(namespace).Get(name, metav1.GetOptions{}) @@ -145,9 +158,26 @@ func (c *controller) scheduleRenewal(crt *v1alpha1.Certificate) { c.scheduledWorkQueue.Add(key, renewIn) } +func (c *controller) prepare(issuer issuer.Interface, crt *v1alpha1.Certificate) error { + log.Printf("Preparing Certificate '%s/%s'", crt.Namespace, crt.Name) + // TODO: move this to after the certificate check to avoid unneeded authorization checks + err := issuer.Prepare(crt) + + if err != nil { + return err + } + + log.Printf("Finished preparing Certificate '%s/%s'", crt.Namespace, crt.Name) + return nil +} + // return an error on failure. If retrieval is succesful, the certificate data // and private key will be stored in the named secret func (c *controller) issue(issuer issuer.Interface, crt *v1alpha1.Certificate) error { + if err := c.prepare(issuer, crt); err != nil { + return err + } + log.Printf("[%s/%s] Issuing certificate...", crt.Namespace, crt.Name) key, cert, err := issuer.Issue(crt) if err != nil { @@ -178,6 +208,10 @@ func (c *controller) issue(issuer issuer.Interface, crt *v1alpha1.Certificate) e // return an error on failure. If renewal is succesful, the certificate data // and private key will be stored in the named secret func (c *controller) renew(issuer issuer.Interface, crt *v1alpha1.Certificate) error { + if err := c.prepare(issuer, crt); err != nil { + return err + } + log.Printf("[%s/%s] Renewing certificate...", crt.Namespace, crt.Name) key, cert, err := issuer.Renew(crt) if err != nil { diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 6f177e3ed..b3ff3c29a 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -7,17 +7,22 @@ import ( type ProcessFunc func(interface{}) -type ScheduledWorkQueue struct { +type ScheduledWorkQueue interface { + Add(interface{}, time.Duration) + Forget(interface{}) +} + +type scheduledWorkQueue struct { processFunc ProcessFunc work map[interface{}]*time.Timer workLock sync.Mutex } -func NewScheduledWorkQueue(processFunc ProcessFunc) *ScheduledWorkQueue { - return &ScheduledWorkQueue{processFunc, make(map[interface{}]*time.Timer), sync.Mutex{}} +func NewScheduledWorkQueue(processFunc ProcessFunc) ScheduledWorkQueue { + return &scheduledWorkQueue{processFunc, make(map[interface{}]*time.Timer), sync.Mutex{}} } -func (s *ScheduledWorkQueue) Add(obj interface{}, duration time.Duration) { +func (s *scheduledWorkQueue) Add(obj interface{}, duration time.Duration) { s.clearTimer(obj) s.work[obj] = time.AfterFunc(duration, func() { defer s.clearTimer(obj) @@ -25,11 +30,11 @@ func (s *ScheduledWorkQueue) Add(obj interface{}, duration time.Duration) { }) } -func (s *ScheduledWorkQueue) Forget(obj interface{}) { +func (s *scheduledWorkQueue) Forget(obj interface{}) { s.clearTimer(obj) } -func (s *ScheduledWorkQueue) clearTimer(obj interface{}) { +func (s *scheduledWorkQueue) clearTimer(obj interface{}) { s.workLock.Lock() defer s.workLock.Unlock() if timer, ok := s.work[obj]; ok { diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go new file mode 100644 index 000000000..556a362bf --- /dev/null +++ b/pkg/scheduler/scheduler_test.go @@ -0,0 +1,70 @@ +package scheduler + +import ( + "sync" + "testing" + "time" +) + +func TestAdd(t *testing.T) { + var wg sync.WaitGroup + type testT struct { + obj string + duration time.Duration + } + tests := []testT{ + {"test500", time.Millisecond * 500}, + {"test1000", time.Second * 1}, + {"test3000", time.Second * 3}, + } + for _, test := range tests { + wg.Add(1) + t.Run(test.obj, func(test testT) func(*testing.T) { + return func(t *testing.T) { + startTime := time.Now() + queue := NewScheduledWorkQueue(func(obj interface{}) { + defer wg.Done() + durationEarly := test.duration - time.Now().Sub(startTime) + if durationEarly > 0 { + t.Errorf("got queue item %.2f seconds too early", float64(durationEarly)/float64(time.Second)) + } + if obj != test.obj { + t.Errorf("expected obj '%+v' but got obj '%+v'", test.obj, obj) + } + }) + queue.Add(test.obj, test.duration) + } + }(test)) + } + + wg.Wait() +} + +func TestForget(t *testing.T) { + var wg sync.WaitGroup + type testT struct { + obj string + duration time.Duration + } + tests := []testT{ + {"test500", time.Millisecond * 500}, + {"test1000", time.Second * 1}, + {"test3000", time.Second * 3}, + } + for _, test := range tests { + wg.Add(1) + t.Run(test.obj, func(test testT) func(*testing.T) { + return func(t *testing.T) { + defer wg.Done() + queue := NewScheduledWorkQueue(func(obj interface{}) { + t.Errorf("scheduled function should never be called") + }) + queue.Add(test.obj, test.duration) + queue.Forget(test.obj) + time.Sleep(test.duration * 2) + } + }(test)) + } + + wg.Wait() +}