From dc801014768c472bc25fcec8ec9b807dd511ebc9 Mon Sep 17 00:00:00 2001 From: James Munnelly Date: Sat, 5 Aug 2017 17:39:56 +0100 Subject: [PATCH] Check for DNS propagation before accept authorization from ACME --- pkg/controller/certificates/sync.go | 2 + pkg/issuer/acme/acme.go | 6 +- pkg/issuer/acme/dns/clouddns/clouddns.go | 15 +- pkg/issuer/acme/dns/dns.go | 63 ++++-- pkg/issuer/acme/dns/util/dns.go | 9 + pkg/issuer/acme/dns/util/wait.go | 239 +++++++++++++++++++++++ pkg/issuer/acme/http/http.go | 10 +- pkg/issuer/acme/issue.go | 4 +- pkg/issuer/acme/prepare.go | 8 +- 9 files changed, 329 insertions(+), 27 deletions(-) create mode 100644 pkg/issuer/acme/dns/util/dns.go create mode 100644 pkg/issuer/acme/dns/util/wait.go diff --git a/pkg/controller/certificates/sync.go b/pkg/controller/certificates/sync.go index 095d318a8..d718591ed 100644 --- a/pkg/controller/certificates/sync.go +++ b/pkg/controller/certificates/sync.go @@ -33,6 +33,7 @@ func (c *controller) sync(crt *v1alpha1.Certificate) error { return fmt.Errorf("error getting issuer implementation for issuer '%s': %s", issuerObj.Name, err.Error()) } + log.Printf("Preparing Issuer '%s/%s' and Certificate '%s/%s'", issuerObj.Namespace, issuerObj.Name, crt.Namespace, crt.Name) // TODO: move this to after the certificate check to avoid unneeded authorization checks err = i.Prepare(crt) @@ -40,6 +41,7 @@ func (c *controller) sync(crt *v1alpha1.Certificate) error { return err } + log.Printf("Finished preparing with Issuer '%s/%s' and Certificate '%s/%s'", issuerObj.Namespace, issuerObj.Name, crt.Namespace, crt.Name) // step one: check if referenced secret exists, if not, trigger issue event secret, err := c.secretLister.Secrets(crt.Namespace).Get(crt.Spec.SecretName) diff --git a/pkg/issuer/acme/acme.go b/pkg/issuer/acme/acme.go index f7851f271..effb8c696 100644 --- a/pkg/issuer/acme/acme.go +++ b/pkg/issuer/acme/acme.go @@ -1,6 +1,7 @@ package acme import ( + "context" "fmt" "k8s.io/client-go/informers" @@ -43,8 +44,9 @@ func New(issuer *v1alpha1.Issuer, } type solver interface { - Present(crt *v1alpha1.Certificate, domain, token, key string) error - CleanUp(crt *v1alpha1.Certificate, domain, token, key string) error + Present(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error + Wait(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error + CleanUp(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error } func (a *Acme) solverFor(challengeType string) (solver, error) { diff --git a/pkg/issuer/acme/dns/clouddns/clouddns.go b/pkg/issuer/acme/dns/clouddns/clouddns.go index 355f3f9c6..0d81e79b1 100644 --- a/pkg/issuer/acme/dns/clouddns/clouddns.go +++ b/pkg/issuer/acme/dns/clouddns/clouddns.go @@ -8,13 +8,12 @@ import ( "os" "time" - "github.com/xenolf/lego/acme" - "golang.org/x/net/context" "golang.org/x/oauth2" "golang.org/x/oauth2/google" - "google.golang.org/api/dns/v1" + + "github.com/munnerz/cert-manager/pkg/issuer/acme/dns/util" ) // DNSProvider is an implementation of the DNSProvider interface. @@ -100,8 +99,8 @@ func NewDNSProviderServiceAccountBytes(project string, saBytes []byte) (*DNSProv } // Present creates a TXT record to fulfil the dns-01 challenge. -func (c *DNSProvider) Present(domain, token, keyAuth string) error { - fqdn, value, ttl := acme.DNS01Record(domain, keyAuth) +func (c *DNSProvider) Present(domain, token, key string) error { + fqdn, value, ttl := util.DNS01Record(domain, key) zone, err := c.getHostedZone(domain) if err != nil { @@ -147,8 +146,8 @@ func (c *DNSProvider) Present(domain, token, keyAuth string) error { } // CleanUp removes the TXT record matching the specified parameters. -func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error { - fqdn, _, _ := acme.DNS01Record(domain, keyAuth) +func (c *DNSProvider) CleanUp(domain, token, key string) error { + fqdn, _, _ := util.DNS01Record(domain, key) zone, err := c.getHostedZone(domain) if err != nil { @@ -180,7 +179,7 @@ func (c *DNSProvider) Timeout() (timeout, interval time.Duration) { // getHostedZone returns the managed-zone func (c *DNSProvider) getHostedZone(domain string) (string, error) { - authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers) + authZone, err := util.FindZoneByFqdn(util.ToFqdn(domain), util.RecursiveNameservers) if err != nil { return "", err } diff --git a/pkg/issuer/acme/dns/dns.go b/pkg/issuer/acme/dns/dns.go index 730552c11..744eb1e93 100644 --- a/pkg/issuer/acme/dns/dns.go +++ b/pkg/issuer/acme/dns/dns.go @@ -1,13 +1,17 @@ package dns import ( + "context" "fmt" + "log" + "time" "k8s.io/client-go/kubernetes" corev1listers "k8s.io/client-go/listers/core/v1" "github.com/munnerz/cert-manager/pkg/apis/certmanager/v1alpha1" "github.com/munnerz/cert-manager/pkg/issuer/acme/dns/clouddns" + "github.com/munnerz/cert-manager/pkg/issuer/acme/dns/util" ) const ( @@ -17,6 +21,7 @@ const ( type solver interface { Present(domain, token, key string) error CleanUp(domain, token, key string) error + Timeout() (timeout, interval time.Duration) } type Solver struct { @@ -25,15 +30,59 @@ type Solver struct { secretLister corev1listers.SecretLister } -func (s *Solver) Present(crt *v1alpha1.Certificate, domain, token, key string) error { +func (s *Solver) Present(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error { slv, err := s.solverFor(crt, domain) if err != nil { return err } + log.Printf("presenting key: %s", key) return slv.Present(domain, token, key) } -func (s *Solver) CleanUp(crt *v1alpha1.Certificate, domain, token, key string) error { +func (s *Solver) Wait(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error { + slv, err := s.solverFor(crt, domain) + if err != nil { + return err + } + + type boolErr struct { + bool + error + } + + fqdn, value, ttl := util.DNS01Record(domain, key) + + log.Printf("[%s] Checking DNS record propagation using %+v", domain, util.RecursiveNameservers) + + timeout, interval := slv.Timeout() + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + for { + select { + case r := <-func() <-chan boolErr { + out := make(chan boolErr, 1) + go func() { + ok, err := util.PreCheckDNS(fqdn, value) + out <- boolErr{ok, err} + }() + return out + }(): + if r.bool { + // TODO: move this to somewhere else + // TODO: make this wait for whatever the record *was*, not is now + log.Printf("sleeping for dns record for '%s' ttl %ds before returning from Wait", fqdn, ttl) + time.Sleep(time.Second * time.Duration(ttl)) + return nil + } + log.Printf("[%s] dns record not yet propegated", domain) + time.Sleep(interval) + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (s *Solver) CleanUp(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error { slv, err := s.solverFor(crt, domain) if err != nil { return err @@ -58,16 +107,6 @@ func (s *Solver) solverFor(crt *v1alpha1.Certificate, domain string) (solver, er var impl solver switch { case providerConfig.CloudDNS != nil: - if providerConfig.CloudDNS.ServiceAccount == "" { - impl, err = clouddns.NewDNSProviderCredentials(providerConfig.CloudDNS.Project) - - if err != nil { - return nil, fmt.Errorf("error instantiating google clouddns challenge solver: %s", err.Error()) - } - - break - } - saSecret, err := s.secretLister.Secrets(s.issuer.Namespace).Get(providerConfig.CloudDNS.ServiceAccount) if err != nil { return nil, fmt.Errorf("error getting clouddns service account: %s", err.Error()) diff --git a/pkg/issuer/acme/dns/util/dns.go b/pkg/issuer/acme/dns/util/dns.go new file mode 100644 index 000000000..03a94e8a2 --- /dev/null +++ b/pkg/issuer/acme/dns/util/dns.go @@ -0,0 +1,9 @@ +package util + +import "fmt" + +// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge +// TODO: move this into a non-generic place by resolving import cycle in dns package +func DNS01Record(domain, value string) (string, string, int) { + return fmt.Sprintf("_acme-challenge.%s.", domain), value, 60 +} diff --git a/pkg/issuer/acme/dns/util/wait.go b/pkg/issuer/acme/dns/util/wait.go new file mode 100644 index 000000000..1f0d53a9d --- /dev/null +++ b/pkg/issuer/acme/dns/util/wait.go @@ -0,0 +1,239 @@ +package util + +import ( + "fmt" + "log" + "net" + "strings" + "time" + + "github.com/miekg/dns" + "golang.org/x/net/publicsuffix" +) + +type preCheckDNSFunc func(fqdn, value string) (bool, error) + +var ( + // PreCheckDNS checks DNS propagation before notifying ACME that + // the DNS challenge is ready. + PreCheckDNS preCheckDNSFunc = checkDNSPropagation + fqdnToZone = map[string]string{} +) + +const defaultResolvConf = "/etc/resolv.conf" + +var defaultNameservers = []string{ + "google-public-dns-a.google.com:53", + "google-public-dns-b.google.com:53", +} + +var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers) + +// DNSTimeout is used to override the default DNS timeout of 10 seconds. +var DNSTimeout = 10 * time.Second + +// getNameservers attempts to get systems nameservers before falling back to the defaults +func getNameservers(path string, defaults []string) []string { + config, err := dns.ClientConfigFromFile(path) + if err != nil || len(config.Servers) == 0 { + return defaults + } + + systemNameservers := []string{} + for _, server := range config.Servers { + // ensure all servers have a port number + if _, _, err := net.SplitHostPort(server); err != nil { + systemNameservers = append(systemNameservers, net.JoinHostPort(server, "53")) + } else { + systemNameservers = append(systemNameservers, server) + } + } + return systemNameservers +} + +// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers. +func checkDNSPropagation(fqdn, value string) (bool, error) { + // Initial attempt to resolve at the recursive NS + r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameservers, true) + if err != nil { + return false, err + } + if r.Rcode == dns.RcodeSuccess { + // If we see a CNAME here then use the alias + for _, rr := range r.Answer { + if cn, ok := rr.(*dns.CNAME); ok { + if cn.Hdr.Name == fqdn { + fqdn = cn.Target + break + } + } + } + } + + authoritativeNss, err := lookupNameservers(fqdn) + if err != nil { + return false, err + } + + return checkAuthoritativeNss(fqdn, value, authoritativeNss) +} + +// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record. +func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) { + for _, ns := range nameservers { + r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false) + if err != nil { + return false, err + } + + if r.Rcode != dns.RcodeSuccess { + return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn) + } + + log.Printf("looking up txt record for fqdn '%s'", fqdn) + var found bool + for _, rr := range r.Answer { + if txt, ok := rr.(*dns.TXT); ok { + if strings.Join(txt.Txt, "") == value { + found = true + break + } + } + } + + if !found { + return false, fmt.Errorf("NS %s did not return the expected TXT record", ns) + } + } + + return true, nil +} + +// dnsQuery will query a nameserver, iterating through the supplied servers as it retries +// The nameserver should include a port, to facilitate testing where we talk to a mock dns server. +func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) { + m := new(dns.Msg) + m.SetQuestion(fqdn, rtype) + m.SetEdns0(4096, false) + + if !recursive { + m.RecursionDesired = false + } + + // Will retry the request based on the number of servers (n+1) + for i := 1; i <= len(nameservers)+1; i++ { + ns := nameservers[i%len(nameservers)] + udp := &dns.Client{Net: "udp", Timeout: DNSTimeout} + in, _, err = udp.Exchange(m, ns) + + if err == dns.ErrTruncated { + tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout} + // If the TCP request suceeds, the err will reset to nil + in, _, err = tcp.Exchange(m, ns) + } + + if err == nil { + break + } + } + return +} + +// lookupNameservers returns the authoritative nameservers for the given fqdn. +func lookupNameservers(fqdn string) ([]string, error) { + var authoritativeNss []string + + zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers) + if err != nil { + return nil, fmt.Errorf("Could not determine the zone: %v", err) + } + + r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true) + if err != nil { + return nil, err + } + + for _, rr := range r.Answer { + if ns, ok := rr.(*dns.NS); ok { + authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns)) + } + } + + if len(authoritativeNss) > 0 { + return authoritativeNss, nil + } + return nil, fmt.Errorf("Could not determine authoritative nameservers") +} + +// FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the +// domain labels until the nameserver returns a SOA record in the answer section. +func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) { + // Do we have it cached? + if zone, ok := fqdnToZone[fqdn]; ok { + return zone, nil + } + + labelIndexes := dns.Split(fqdn) + for _, index := range labelIndexes { + domain := fqdn[index:] + // Give up if we have reached the TLD + if isTLD(domain) { + break + } + + in, err := dnsQuery(domain, dns.TypeSOA, nameservers, true) + if err != nil { + return "", err + } + + // Any response code other than NOERROR and NXDOMAIN is treated as error + if in.Rcode != dns.RcodeNameError && in.Rcode != dns.RcodeSuccess { + return "", fmt.Errorf("Unexpected response code '%s' for %s", + dns.RcodeToString[in.Rcode], domain) + } + + // Check if we got a SOA RR in the answer section + if in.Rcode == dns.RcodeSuccess { + for _, ans := range in.Answer { + if soa, ok := ans.(*dns.SOA); ok { + zone := soa.Hdr.Name + fqdnToZone[fqdn] = zone + return zone, nil + } + } + } + } + + return "", fmt.Errorf("Could not find the start of authority") +} + +func isTLD(domain string) bool { + publicsuffix, _ := publicsuffix.PublicSuffix(UnFqdn(domain)) + if publicsuffix == UnFqdn(domain) { + return true + } + return false +} + +// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing. +func ClearFqdnCache() { + fqdnToZone = map[string]string{} +} + +// ToFqdn converts the name into a fqdn appending a trailing dot. +func ToFqdn(name string) string { + n := len(name) + if n == 0 || name[n-1] == '.' { + return name + } + return name + "." +} + +// UnFqdn converts the fqdn into a name removing the trailing dot. +func UnFqdn(name string) string { + n := len(name) + if n != 0 && name[n-1] == '.' { + return name[:n-1] + } + return name +} diff --git a/pkg/issuer/acme/http/http.go b/pkg/issuer/acme/http/http.go index 16595a971..a9c247d74 100644 --- a/pkg/issuer/acme/http/http.go +++ b/pkg/issuer/acme/http/http.go @@ -1,6 +1,7 @@ package http import ( + "context" "fmt" "log" "net/http" @@ -39,13 +40,18 @@ func NewSolver() *Solver { return &Solver{} } -func (s *Solver) Present(crt *v1alpha1.Certificate, domain, token, key string) error { +func (s *Solver) Present(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error { solver.addChallenge(challenge{domain, token, key}) return nil } // todo -func (s *Solver) CleanUp(crt *v1alpha1.Certificate, domain, token, key string) error { +func (s *Solver) Wait(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error { + return nil +} + +// todo +func (s *Solver) CleanUp(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error { return nil } diff --git a/pkg/issuer/acme/issue.go b/pkg/issuer/acme/issue.go index 31b5d2039..c5990afa5 100644 --- a/pkg/issuer/acme/issue.go +++ b/pkg/issuer/acme/issue.go @@ -84,7 +84,7 @@ func (a *Acme) obtainCertificate(crt *v1alpha1.Certificate) (privateKeyPem []byt return nil, nil, fmt.Errorf("error creating certificate request: %s", err) } - certSlice, certUrl, err := cl.CreateCert( + certSlice, certURL, err := cl.CreateCert( context.Background(), csr, 0, @@ -99,7 +99,7 @@ func (a *Acme) obtainCertificate(crt *v1alpha1.Certificate) (privateKeyPem []byt pem.Encode(certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: cert}) } - log.Printf("successfully got certificate: domains=%+v url=%s", domains, certUrl) + log.Printf("successfully got certificate: domains=%+v url=%s", domains, certURL) return privateKeyPem, certBuffer.Bytes(), nil } diff --git a/pkg/issuer/acme/prepare.go b/pkg/issuer/acme/prepare.go index c8cf56970..0b357d078 100644 --- a/pkg/issuer/acme/prepare.go +++ b/pkg/issuer/acme/prepare.go @@ -141,11 +141,17 @@ func (a *Acme) prepare(crt *v1alpha1.Certificate) error { } log.Printf("presenting challenge for domain %s, token %s key %s", auth.domain, token, key) - err = solver.Present(crt, auth.domain, token, key) + err = solver.Present(context.Background(), crt, auth.domain, token, key) if err != nil { return fmt.Errorf("error presenting acme authorization for domain '%s': %s", auth.domain, err.Error()) } + log.Printf("waiting for key to be available to acme servers for domain %s", auth.domain) + err = solver.Wait(context.Background(), crt, auth.domain, token, key) + if err != nil { + return fmt.Errorf("error waiting for key to be available for domain '%s': %s", auth.domain, err.Error()) + } + log.Printf("accepting %s challenge for domain %s", challengeType, auth.domain) challenge, err = cl.Accept(context.Background(), challenge) if err != nil {