cert-manager/pkg/issuer/acme/dns/util/wait.go
Thomas Miller dacd0b45cb Control authoritative dns01 server check.
Adds cmd flag for controlling if authoritative dns servers are used to
check RR propagation or just normal resolvers.

This change is added so that constrained environments can control more
aspects of DNS queries performed.

- Applying PR feedback

Signed-off-by: Thomas Miller <thomas@tlm.id.au>
2019-01-12 20:17:28 +10:00

298 lines
7.6 KiB
Go

// +skip_license_check
/*
This file contains portions of code directly taken from the 'xenolf/lego' project.
A copy of the license for this code can be found in the file named LICENSE in
this directory.
*/
package util
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/golang/glog"
"github.com/miekg/dns"
)
// preCheckDNSFunc is the signature of PreCheckDNS: it reports whether the TXT
// record at fqdn carries value, querying nameservers directly or, when
// useAuthoritative is set, the authoritative servers for fqdn's zone.
type preCheckDNSFunc func(fqdn, value string, nameservers []string,
	useAuthoritative bool) (bool, error)

var (
	// PreCheckDNS checks DNS propagation before notifying ACME that
	// the DNS challenge is ready. It is a variable rather than a function,
	// presumably so callers/tests can substitute it — TODO confirm.
	PreCheckDNS preCheckDNSFunc = checkDNSPropagation

	// fqdnToZoneLock guards every read and write of fqdnToZone.
	fqdnToZoneLock sync.RWMutex
	// fqdnToZone caches zone apexes discovered by FindZoneByFqdn, keyed by FQDN.
	fqdnToZone = map[string]string{}
)

// defaultResolvConf is the resolver configuration consulted when the package
// is initialised to discover the system's nameservers.
const defaultResolvConf = "/etc/resolv.conf"

var (
	// defaultNameservers are used when defaultResolvConf cannot be read or
	// lists no servers.
	defaultNameservers = []string{
		"8.8.8.8:53",
		"8.8.4.4:53",
	}

	// RecursiveNameservers holds the host:port resolvers taken from
	// defaultResolvConf at package initialisation, or defaultNameservers if
	// that file is unusable.
	RecursiveNameservers = getNameservers(defaultResolvConf, defaultNameservers)

	// DNSTimeout is used to override the default DNS timeout of 10 seconds.
	DNSTimeout = 10 * time.Second
)
// getNameservers reads the resolver configuration at path and returns its
// nameservers as host:port strings, adding port 53 to any entry without one.
// If the file cannot be parsed or lists no servers, defaults is returned
// unchanged.
func getNameservers(path string, defaults []string) []string {
	config, err := dns.ClientConfigFromFile(path)
	if err != nil || len(config.Servers) == 0 {
		return defaults
	}

	servers := make([]string, 0, len(config.Servers))
	for _, host := range config.Servers {
		addr := host
		// SplitHostPort fails for entries without a port (including bare
		// IPv6 literals); JoinHostPort then appends ":53", bracketing IPv6.
		if _, _, splitErr := net.SplitHostPort(host); splitErr != nil {
			addr = net.JoinHostPort(host, "53")
		}
		servers = append(servers, addr)
	}
	return servers
}
// updateDomainWithCName follows the CNAME chain for fqdn found in the answer
// section of r and returns the name at the end of the chain. If r holds no
// CNAME owned by fqdn, fqdn is returned unchanged.
//
// Recursive resolvers return the complete chain (a -> b -> c) in a single
// answer, and the TXT record lives at the final target. Stopping after the
// first hop would make the propagation check look at an intermediate name.
// Owner names are compared case-insensitively, as DNS names are.
func updateDomainWithCName(r *dns.Msg, fqdn string) string {
	// Each record can advance the chain at most once, so bounding the number
	// of hops by the answer count also terminates on a CNAME loop.
	for hops := 0; hops < len(r.Answer); hops++ {
		next := ""
		for _, rr := range r.Answer {
			if cn, ok := rr.(*dns.CNAME); ok && strings.EqualFold(cn.Hdr.Name, fqdn) {
				next = cn.Target
				break
			}
		}
		if next == "" {
			break
		}
		glog.Infof("Updating FQDN: %s with it's CNAME: %s", fqdn, next)
		fqdn = next
	}
	return fqdn
}
// checkDNSPropagation reports whether the TXT record at fqdn carries value.
//
// It first queries nameservers (host:port resolvers) for the TXT record. On a
// NOERROR answer it follows any CNAME so that the real record owner is checked.
// With useAuthoritative unset, the (possibly updated) name is then checked
// against nameservers themselves. With useAuthoritative set, the authoritative
// NS hosts for the name's zone are looked up through nameservers and each is
// queried on port 53 instead. Query and lookup failures are returned as errors.
func checkDNSPropagation(fqdn, value string, nameservers []string,
	useAuthoritative bool) (bool, error) {
	resp, err := dnsQuery(fqdn, dns.TypeTXT, nameservers, true)
	if err != nil {
		return false, err
	}
	if resp.Rcode == dns.RcodeSuccess {
		fqdn = updateDomainWithCName(resp, fqdn)
	}

	targets := nameservers
	if useAuthoritative {
		hosts, lookupErr := lookupNameservers(fqdn, nameservers)
		if lookupErr != nil {
			return false, lookupErr
		}
		targets = make([]string, len(hosts))
		for i, host := range hosts {
			targets[i] = net.JoinHostPort(host, "53")
		}
	}
	return checkAuthoritativeNss(fqdn, value, targets)
}
// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{ns}, true)
if err != nil {
return false, err
}
// NXDomain response is not really an error, just waiting for propagation to happen
if !(r.Rcode == dns.RcodeSuccess || r.Rcode == dns.RcodeNameError) {
return false, fmt.Errorf("NS %s returned %s for %s", ns, dns.RcodeToString[r.Rcode], fqdn)
}
glog.V(6).Infof("Looking up TXT records for %q", fqdn)
var found bool
for _, rr := range r.Answer {
if txt, ok := rr.(*dns.TXT); ok {
if strings.Join(txt.Txt, "") == value {
found = true
break
}
}
}
if !found {
return false, nil
}
}
return true, nil
}
// dnsQuery sends one question of type rtype for fqdn and returns the first
// response for which the exchange reports no error.
//
// Each entry in nameservers must be a host:port address. The explicit port
// also lets tests talk to a mock DNS server. Servers are tried in the order
// given, wrapping around so that len(nameservers)+1 attempts are made in total
// (the first server is asked again last).
//
// A UDP attempt that fails with dns.ErrTruncated or a UDP read timeout is
// retried once over TCP against the same server. If every attempt fails, the
// error from the last one is returned. When recursive is false, the query's
// RecursionDesired flag is cleared. An empty nameservers slice is reported as
// an error rather than causing a divide-by-zero panic.
func dnsQuery(fqdn string, rtype uint16, nameservers []string, recursive bool) (in *dns.Msg, err error) {
	if len(nameservers) == 0 {
		return nil, fmt.Errorf("no nameservers available to query %s", fqdn)
	}

	m := new(dns.Msg)
	m.SetQuestion(fqdn, rtype)
	m.SetEdns0(4096, false)
	if !recursive {
		m.RecursionDesired = false
	}

	// Neither client depends on the attempt, so build them once.
	udp := &dns.Client{Net: "udp", Timeout: DNSTimeout}
	tcp := &dns.Client{Net: "tcp", Timeout: DNSTimeout}

	// n+1 attempts, starting with the first server: resolver lists are
	// normally in preference order, so it must not be skipped.
	for i := 0; i <= len(nameservers); i++ {
		ns := nameservers[i%len(nameservers)]
		in, _, err = udp.Exchange(m, ns)
		if err == dns.ErrTruncated ||
			(err != nil && strings.HasPrefix(err.Error(), "read udp") && strings.HasSuffix(err.Error(), "i/o timeout")) {
			glog.V(6).Infof("UDP dns lookup failed, retrying with TCP: %v", err)
			// If the TCP request succeeds, err is reset to nil.
			in, _, err = tcp.Exchange(m, ns)
		}
		if err == nil {
			return in, nil
		}
	}
	return in, err
}
// lookupNameservers finds the zone containing fqdn via the given host:port
// resolvers. It then returns the lower-cased targets of that zone's NS records
// (fully qualified, without a port). It errors if the zone cannot be
// determined, the NS query fails, or the answer contains no NS records.
func lookupNameservers(fqdn string, nameservers []string) ([]string, error) {
	zone, err := FindZoneByFqdn(fqdn, nameservers)
	if err != nil {
		return nil, fmt.Errorf("Could not determine the zone: %v", err)
	}

	resp, err := dnsQuery(zone, dns.TypeNS, nameservers, true)
	if err != nil {
		return nil, err
	}

	var hosts []string
	for _, rr := range resp.Answer {
		record, isNS := rr.(*dns.NS)
		if !isNS {
			continue
		}
		hosts = append(hosts, strings.ToLower(record.Ns))
	}

	if len(hosts) == 0 {
		return nil, fmt.Errorf("Could not determine authoritative nameservers")
	}
	return hosts, nil
}
// FindZoneByFqdn determines the zone apex for the given fqdn by walking up its
// labels, from the full name towards the root. It queries nameservers for an
// SOA record at each candidate and stops at the first NOERROR answer whose
// answer section holds an SOA.
//
// Candidates answering NXDOMAIN are skipped. So are candidates whose answer
// contains a CNAME, since a CNAME cannot sit at a zone apex. Any other response
// code aborts with an error. Results are cached process-wide per fqdn. The
// cache key does not include nameservers, so a cached zone is returned whatever
// servers are passed later.
func FindZoneByFqdn(fqdn string, nameservers []string) (string, error) {
	fqdnToZoneLock.RLock()
	cachedZone, hit := fqdnToZone[fqdn]
	fqdnToZoneLock.RUnlock()
	if hit {
		return cachedZone, nil
	}

	for _, offset := range dns.Split(fqdn) {
		candidate := fqdn[offset:]
		msg, err := dnsQuery(candidate, dns.TypeSOA, nameservers, true)
		if err != nil {
			return "", err
		}

		switch msg.Rcode {
		case dns.RcodeNameError:
			continue
		case dns.RcodeSuccess:
			// handled below
		default:
			return "", fmt.Errorf("Unexpected response code '%s' for %s",
				dns.RcodeToString[msg.Rcode], candidate)
		}

		if dnsMsgContainsCNAME(msg) {
			continue
		}

		for _, rr := range msg.Answer {
			soa, isSOA := rr.(*dns.SOA)
			if !isSOA {
				continue
			}
			apex := soa.Hdr.Name
			fqdnToZoneLock.Lock()
			fqdnToZone[fqdn] = apex
			fqdnToZoneLock.Unlock()
			return apex, nil
		}
	}

	return "", fmt.Errorf("Could not find the start of authority")
}
// dnsMsgContainsCNAME reports whether any record in the answer section of msg
// is a CNAME.
func dnsMsgContainsCNAME(msg *dns.Msg) bool {
	for i := range msg.Answer {
		if _, isCNAME := msg.Answer[i].(*dns.CNAME); isCNAME {
			return true
		}
	}
	return false
}
// ToFqdn returns name with a trailing dot appended, turning it into a fully
// qualified domain name. Names that already end in a dot, and the empty
// string, are returned unchanged.
func ToFqdn(name string) string {
	if name != "" && name[len(name)-1] != '.' {
		return name + "."
	}
	return name
}
// UnFqdn strips a single trailing dot from name, if present, turning a fully
// qualified domain name back into a plain name. Other input is returned as is.
func UnFqdn(name string) string {
	if last := len(name) - 1; last >= 0 && name[last] == '.' {
		name = name[:last]
	}
	return name
}
// WaitFor polls the given function 'f', once every 'interval', up to 'timeout'.
func WaitFor(timeout, interval time.Duration, f func() (bool, error)) error {
var lastErr string
timeup := time.After(timeout)
for {
select {
case <-timeup:
return fmt.Errorf("Time limit exceeded. Last error: %s", lastErr)
default:
}
stop, err := f()
if stop {
return nil
}
if err != nil {
lastErr = err.Error()
}
time.Sleep(interval)
}
}