Check for DNS propagation before accept authorization from ACME

This commit is contained in:
James Munnelly 2017-08-05 17:39:56 +01:00
parent 45a1ab2f2b
commit dc80101476
9 changed files with 329 additions and 27 deletions

View File

@ -33,6 +33,7 @@ func (c *controller) sync(crt *v1alpha1.Certificate) error {
return fmt.Errorf("error getting issuer implementation for issuer '%s': %s", issuerObj.Name, err.Error())
}
log.Printf("Preparing Issuer '%s/%s' and Certificate '%s/%s'", issuerObj.Namespace, issuerObj.Name, crt.Namespace, crt.Name)
// TODO: move this to after the certificate check to avoid unneeded authorization checks
err = i.Prepare(crt)
@ -40,6 +41,7 @@ func (c *controller) sync(crt *v1alpha1.Certificate) error {
return err
}
log.Printf("Finished preparing with Issuer '%s/%s' and Certificate '%s/%s'", issuerObj.Namespace, issuerObj.Name, crt.Namespace, crt.Name)
// step one: check if referenced secret exists, if not, trigger issue event
secret, err := c.secretLister.Secrets(crt.Namespace).Get(crt.Spec.SecretName)

View File

@ -1,6 +1,7 @@
package acme
import (
"context"
"fmt"
"k8s.io/client-go/informers"
@ -43,8 +44,9 @@ func New(issuer *v1alpha1.Issuer,
}
type solver interface {
Present(crt *v1alpha1.Certificate, domain, token, key string) error
CleanUp(crt *v1alpha1.Certificate, domain, token, key string) error
Present(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error
Wait(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error
CleanUp(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error
}
func (a *Acme) solverFor(challengeType string) (solver, error) {

View File

@ -8,13 +8,12 @@ import (
"os"
"time"
"github.com/xenolf/lego/acme"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/dns/v1"
"github.com/munnerz/cert-manager/pkg/issuer/acme/dns/util"
)
// DNSProvider is an implementation of the DNSProvider interface.
@ -100,8 +99,8 @@ func NewDNSProviderServiceAccountBytes(project string, saBytes []byte) (*DNSProv
}
// Present creates a TXT record to fulfil the dns-01 challenge.
func (c *DNSProvider) Present(domain, token, keyAuth string) error {
fqdn, value, ttl := acme.DNS01Record(domain, keyAuth)
func (c *DNSProvider) Present(domain, token, key string) error {
fqdn, value, ttl := util.DNS01Record(domain, key)
zone, err := c.getHostedZone(domain)
if err != nil {
@ -147,8 +146,8 @@ func (c *DNSProvider) Present(domain, token, keyAuth string) error {
}
// CleanUp removes the TXT record matching the specified parameters.
func (c *DNSProvider) CleanUp(domain, token, keyAuth string) error {
fqdn, _, _ := acme.DNS01Record(domain, keyAuth)
func (c *DNSProvider) CleanUp(domain, token, key string) error {
fqdn, _, _ := util.DNS01Record(domain, key)
zone, err := c.getHostedZone(domain)
if err != nil {
@ -180,7 +179,7 @@ func (c *DNSProvider) Timeout() (timeout, interval time.Duration) {
// getHostedZone returns the managed-zone
func (c *DNSProvider) getHostedZone(domain string) (string, error) {
authZone, err := acme.FindZoneByFqdn(acme.ToFqdn(domain), acme.RecursiveNameservers)
authZone, err := util.FindZoneByFqdn(util.ToFqdn(domain), util.RecursiveNameservers)
if err != nil {
return "", err
}

View File

@ -1,13 +1,17 @@
package dns
import (
"context"
"fmt"
"log"
"time"
"k8s.io/client-go/kubernetes"
corev1listers "k8s.io/client-go/listers/core/v1"
"github.com/munnerz/cert-manager/pkg/apis/certmanager/v1alpha1"
"github.com/munnerz/cert-manager/pkg/issuer/acme/dns/clouddns"
"github.com/munnerz/cert-manager/pkg/issuer/acme/dns/util"
)
const (
@ -17,6 +21,7 @@ const (
type solver interface {
Present(domain, token, key string) error
CleanUp(domain, token, key string) error
Timeout() (timeout, interval time.Duration)
}
type Solver struct {
@ -25,15 +30,59 @@ type Solver struct {
secretLister corev1listers.SecretLister
}
func (s *Solver) Present(crt *v1alpha1.Certificate, domain, token, key string) error {
func (s *Solver) Present(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error {
slv, err := s.solverFor(crt, domain)
if err != nil {
return err
}
log.Printf("presenting key: %s", key)
return slv.Present(domain, token, key)
}
func (s *Solver) CleanUp(crt *v1alpha1.Certificate, domain, token, key string) error {
func (s *Solver) Wait(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error {
slv, err := s.solverFor(crt, domain)
if err != nil {
return err
}
type boolErr struct {
bool
error
}
fqdn, value, ttl := util.DNS01Record(domain, key)
log.Printf("[%s] Checking DNS record propagation using %+v", domain, util.RecursiveNameservers)
timeout, interval := slv.Timeout()
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
for {
select {
case r := <-func() <-chan boolErr {
out := make(chan boolErr, 1)
go func() {
ok, err := util.PreCheckDNS(fqdn, value)
out <- boolErr{ok, err}
}()
return out
}():
if r.bool {
// TODO: move this to somewhere else
// TODO: make this wait for whatever the record *was*, not is now
log.Printf("sleeping for dns record for '%s' ttl %ds before returning from Wait", fqdn, ttl)
time.Sleep(time.Second * time.Duration(ttl))
return nil
}
log.Printf("[%s] dns record not yet propegated", domain)
time.Sleep(interval)
case <-ctx.Done():
return ctx.Err()
}
}
}
func (s *Solver) CleanUp(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error {
slv, err := s.solverFor(crt, domain)
if err != nil {
return err
@ -58,16 +107,6 @@ func (s *Solver) solverFor(crt *v1alpha1.Certificate, domain string) (solver, er
var impl solver
switch {
case providerConfig.CloudDNS != nil:
if providerConfig.CloudDNS.ServiceAccount == "" {
impl, err = clouddns.NewDNSProviderCredentials(providerConfig.CloudDNS.Project)
if err != nil {
return nil, fmt.Errorf("error instantiating google clouddns challenge solver: %s", err.Error())
}
break
}
saSecret, err := s.secretLister.Secrets(s.issuer.Namespace).Get(providerConfig.CloudDNS.ServiceAccount)
if err != nil {
return nil, fmt.Errorf("error getting clouddns service account: %s", err.Error())

View File

@ -0,0 +1,9 @@
package util
import "fmt"
// DNS01Record returns a DNS record which will fulfill the `dns-01` challenge
// TODO: move this into a non-generic place by resolving import cycle in dns package
func DNS01Record(domain, value string) (string, string, int) {
return fmt.Sprintf("_acme-challenge.%s.", domain), value, 60
}

View File

@ -0,0 +1,239 @@
package util
import (
"fmt"
"log"
"net"
"strings"
"time"
"github.com/miekg/dns"
"golang.org/x/net/publicsuffix"
)
type preCheckDNSFunc func(fqdn, value string) (bool, error)
var (
// PreCheckDNS checks DNS propagation before notifying ACME that
// the DNS challenge is ready.
PreCheckDNS preCheckDNSFunc = checkDNSPropagation
fqdnToZone = map[string]string{}
)
const defaultResolvConf = "/etc/resolv.conf"
var defaultNameservers = []string{
"google-public-dns-a.google.com:53",
"google-public-dns-b.google.com:53",
}
var RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)
// DNSTimeout is used to override the default DNS timeout of 10 seconds.
var DNSTimeout = 10 * time.Second
// getNameservers attempts to get systems nameservers before falling back to the defaults
func getNameservers(path string, defaults []string) []string {
config, err := dns.ClientConfigFromFile(path)
if err != nil || len(config.Servers) == 0 {
return defaults
}
systemNameservers := []string{}
for _, server := range config.Servers {
// ensure all servers have a port number
if _, _, err := net.SplitHostPort(server); err != nil {
systemNameservers = append(systemNameservers, net.JoinHostPort(server, "53"))
} else {
systemNameservers = append(systemNameservers, server)
}
}
return systemNameservers
}
// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func checkDNSPropagation(fqdn, value string) (bool, error) {
// Initial attempt to resolve at the recursive NS
r, err := dnsQuery(fqdn, dns.TypeTXT, RecursiveNameservers, true)
if err != nil {
return false, err
}
if r.Rcode == dns.RcodeSuccess {
// If we see a CNAME here then use the alias
for _, rr := range r.Answer {
if cn, ok := rr.(*dns.CNAME); ok {
if cn.Hdr.Name == fqdn {
fqdn = cn.Target
break
}
}
}
}
authoritativeNss, err := lookupNameservers(fqdn)
if err != nil {
return false, err
}
return checkAuthoritativeNss(fqdn, value, authoritativeNss)
}
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
if err != nil {
return false, err
}
if r.Rcode != dns.RcodeSuccess {
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
}
log.Printf("looking up txt record for fqdn '%s'", fqdn)
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
if strings.Join(txt.Txt, "") == value {
found = true
break
}
}
}
if !found {
return false, fmt.Errorf("NS %s did not return the expected TXT record", ns)
}
}
return true, nil
}
// dnsQuery will query a nameserver, iterating through the supplied servers as it retries
// The nameserver should include a port, to facilitate testing where we talk to a mock dns server.
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
m := new(dns.Msg)
m.SetQuestion(fqdn, rtype)
m.SetEdns0(4096, false)
if !recursive {
m.RecursionDesired = false
}
// Will retry the request based on the number of servers (n+1)
for i := 1; i <= len(nameservers)+1; i++ {
ns := nameservers[i%len(nameservers)]
udp := &dns.Client{Net: "udp", Timeout: DNSTimeout}
in, _, err = udp.Exchange(m, ns)
if err == dns.ErrTruncated {
tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout}
// If the TCP request suceeds, the err will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
if err == nil {
break
}
}
return
}
// lookupNameservers returns the authoritative nameservers for the given fqdn.
func lookupNameservers(fqdn string) ([]string, error) {
var authoritativeNss []string
zone, err := FindZoneByFqdn(fqdn, RecursiveNameservers)
if err != nil {
return nil, fmt.Errorf("Could not determine the zone: %v", err)
}
r, err := dnsQuery(zone, dns.TypeNS, RecursiveNameservers, true)
if err != nil {
return nil, err
}
for _, rr := range r.Answer {
if ns, ok := rr.(*dns.NS); ok {
authoritativeNss = append(authoritativeNss, strings.ToLower(ns.Ns))
}
}
if len(authoritativeNss) > 0 {
return authoritativeNss, nil
}
return nil, fmt.Errorf("Could not determine authoritative nameservers")
}
// FindZoneByFqdn determines the zone apex for the given fqdn by recursing up the
// domain labels until the nameserver returns a SOA record in the answer section.
func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
// Do we have it cached?
if zone, ok := fqdnToZone[fqdn]; ok {
return zone, nil
}
labelIndexes := dns.Split(fqdn)
for _, index := range labelIndexes {
domain := fqdn[index:]
// Give up if we have reached the TLD
if isTLD(domain) {
break
}
in, err := dnsQuery(domain, dns.TypeSOA, nameservers, true)
if err != nil {
return "", err
}
// Any response code other than NOERROR and NXDOMAIN is treated as error
if in.Rcode != dns.RcodeNameError && in.Rcode != dns.RcodeSuccess {
return "", fmt.Errorf("Unexpected response code '%s' for %s",
dns.RcodeToString[in.Rcode], domain)
}
// Check if we got a SOA RR in the answer section
if in.Rcode == dns.RcodeSuccess {
for _, ans := range in.Answer {
if soa, ok := ans.(*dns.SOA); ok {
zone := soa.Hdr.Name
fqdnToZone[fqdn] = zone
return zone, nil
}
}
}
}
return "", fmt.Errorf("Could not find the start of authority")
}
func isTLD(domain string) bool {
publicsuffix, _ := publicsuffix.PublicSuffix(UnFqdn(domain))
if publicsuffix == UnFqdn(domain) {
return true
}
return false
}
// ClearFqdnCache clears the cache of fqdn to zone mappings. Primarily used in testing.
func ClearFqdnCache() {
fqdnToZone = map[string]string{}
}
// ToFqdn converts the name into a fqdn appending a trailing dot.
func ToFqdn(name string) string {
n := len(name)
if n == 0 || name[n-1] == '.' {
return name
}
return name + "."
}
// UnFqdn converts the fqdn into a name removing the trailing dot.
func UnFqdn(name string) string {
n := len(name)
if n != 0 && name[n-1] == '.' {
return name[:n-1]
}
return name
}

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"fmt"
"log"
"net/http"
@ -39,13 +40,18 @@ func NewSolver() *Solver {
return &Solver{}
}
func (s *Solver) Present(crt *v1alpha1.Certificate, domain, token, key string) error {
func (s *Solver) Present(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error {
solver.addChallenge(challenge{domain, token, key})
return nil
}
// todo
func (s *Solver) CleanUp(crt *v1alpha1.Certificate, domain, token, key string) error {
func (s *Solver) Wait(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error {
return nil
}
// todo
func (s *Solver) CleanUp(ctx context.Context, crt *v1alpha1.Certificate, domain, token, key string) error {
return nil
}

View File

@ -84,7 +84,7 @@ func (a *Acme) obtainCertificate(crt *v1alpha1.Certificate) (privateKeyPem []byt
return nil, nil, fmt.Errorf("error creating certificate request: %s", err)
}
certSlice, certUrl, err := cl.CreateCert(
certSlice, certURL, err := cl.CreateCert(
context.Background(),
csr,
0,
@ -99,7 +99,7 @@ func (a *Acme) obtainCertificate(crt *v1alpha1.Certificate) (privateKeyPem []byt
pem.Encode(certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: cert})
}
log.Printf("successfully got certificate: domains=%+v url=%s", domains, certUrl)
log.Printf("successfully got certificate: domains=%+v url=%s", domains, certURL)
return privateKeyPem, certBuffer.Bytes(), nil
}

View File

@ -141,11 +141,17 @@ func (a *Acme) prepare(crt *v1alpha1.Certificate) error {
}
log.Printf("presenting challenge for domain %s, token %s key %s", auth.domain, token, key)
err = solver.Present(crt, auth.domain, token, key)
err = solver.Present(context.Background(), crt, auth.domain, token, key)
if err != nil {
return fmt.Errorf("error presenting acme authorization for domain '%s': %s", auth.domain, err.Error())
}
log.Printf("waiting for key to be available to acme servers for domain %s", auth.domain)
err = solver.Wait(context.Background(), crt, auth.domain, token, key)
if err != nil {
return fmt.Errorf("error waiting for key to be available for domain '%s': %s", auth.domain, err.Error())
}
log.Printf("accepting %s challenge for domain %s", challengeType, auth.domain)
challenge, err = cl.Accept(context.Background(), challenge)
if err != nil {