diff --git a/cmd/controller/app/controller.go b/cmd/controller/app/controller.go index 39be1e69b..4e131345d 100644 --- a/cmd/controller/app/controller.go +++ b/cmd/controller/app/controller.go @@ -66,7 +66,7 @@ const controllerAgentName = "cert-manager" // This sets the informer's resync period to 10 hours // following the controller-runtime defaults -//and following discussion: https://github.com/kubernetes-sigs/controller-runtime/pull/88#issuecomment-408500629 +// and following discussion: https://github.com/kubernetes-sigs/controller-runtime/pull/88#issuecomment-408500629 const resyncPeriod = 10 * time.Hour func Run(opts *options.ControllerOptions, stopCh <-chan struct{}) error { @@ -360,6 +360,7 @@ func buildControllerContext(ctx context.Context, opts *options.ControllerOptions HTTP01SolverResourceRequestMemory: HTTP01SolverResourceRequestMemory, HTTP01SolverResourceLimitsCPU: HTTP01SolverResourceLimitsCPU, HTTP01SolverResourceLimitsMemory: HTTP01SolverResourceLimitsMemory, + HTTP01SolverNameservers: opts.ACMEHTTP01SolverNameservers, DNS01CheckAuthoritative: !opts.DNS01RecursiveNameserversOnly, DNS01Nameservers: nameservers, AccountRegistry: acmeAccountRegistry, diff --git a/cmd/controller/app/options/options.go b/cmd/controller/app/options/options.go index c4e793823..06fdc608a 100644 --- a/cmd/controller/app/options/options.go +++ b/cmd/controller/app/options/options.go @@ -79,6 +79,8 @@ type ControllerOptions struct { ACMEHTTP01SolverResourceRequestMemory string ACMEHTTP01SolverResourceLimitsCPU string ACMEHTTP01SolverResourceLimitsMemory string + // Allows specifying a list of custom nameservers to perform HTTP01 checks on. + ACMEHTTP01SolverNameservers []string ClusterIssuerAmbientCredentials bool IssuerAmbientCredentials bool @@ -232,6 +234,7 @@ func NewControllerOptions() *ControllerOptions { DefaultIssuerKind: defaultTLSACMEIssuerKind, DefaultIssuerGroup: defaultTLSACMEIssuerGroup, DefaultAutoCertificateAnnotations: defaultAutoCertificateAnnotations, + ACMEHTTP01SolverNameservers: []string{}, DNS01RecursiveNameservers: []string{}, DNS01RecursiveNameserversOnly: defaultDNS01RecursiveNameserversOnly, EnableCertificateOwnerRef: defaultEnableCertificateOwnerRef, @@ -298,6 +301,11 @@ func (s *ControllerOptions) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&s.ACMEHTTP01SolverResourceLimitsMemory, "acme-http01-solver-resource-limits-memory", defaultACMEHTTP01SolverResourceLimitsMemory, ""+ "Defines the resource limits Memory size when spawning new ACME HTTP01 challenge solver pods.") + fs.StringSliceVar(&s.ACMEHTTP01SolverNameservers, "acme-http01-solver-nameservers", + []string{}, "A list of comma separated dns server endpoints used for "+ + "ACME HTTP01 check requests. This should be a list containing host and "+ + "port, for example 8.8.8.8:53,8.8.4.4:53") + fs.BoolVar(&s.ClusterIssuerAmbientCredentials, "cluster-issuer-ambient-credentials", defaultClusterIssuerAmbientCredentials, ""+ "Whether a cluster-issuer may make use of ambient credentials for issuers. 'Ambient Credentials' are credentials drawn from the environment, metadata services, or local files which are not explicitly configured in the ClusterIssuer API object. "+ "When this flag is enabled, the following sources for credentials are also used: "+ @@ -369,7 +377,7 @@ func (o *ControllerOptions) Validate() error { return fmt.Errorf("invalid value for kube-api-burst: %v must be higher or equal to kube-api-qps: %v", o.KubernetesAPIQPS, o.KubernetesAPIQPS) } - for _, server := range o.DNS01RecursiveNameservers { + for _, server := range append(o.DNS01RecursiveNameservers, o.ACMEHTTP01SolverNameservers...) { // ensure all servers have a port number _, _, err := net.SplitHostPort(server) if err != nil { diff --git a/pkg/controller/context.go b/pkg/controller/context.go index 0319b2476..b0db3dc88 100644 --- a/pkg/controller/context.go +++ b/pkg/controller/context.go @@ -123,6 +123,10 @@ type ACMEOptions struct { // HTTP01SolverResourceLimitsMemory defines the ACME pod's resource limits Memory size HTTP01SolverResourceLimitsMemory resource.Quantity + // HTTP01SolverNameservers is a list of nameservers to use when performing self-checks + // for ACME HTTP01 validations. + HTTP01SolverNameservers []string + // DNS01CheckAuthoritative is a flag for controlling if auth nss are used // for checking propagation of an RR. This is the ideal scenario DNS01CheckAuthoritative bool diff --git a/pkg/issuer/acme/http/BUILD.bazel b/pkg/issuer/acme/http/BUILD.bazel index fcb747ccc..893775d58 100644 --- a/pkg/issuer/acme/http/BUILD.bazel +++ b/pkg/issuer/acme/http/BUILD.bazel @@ -48,8 +48,10 @@ go_test( embed = [":go_default_library"], deps = [ "//pkg/apis/acme/v1:go_default_library", + "//pkg/controller:go_default_library", "//pkg/controller/test:go_default_library", "//test/unit/gen:go_default_library", + "@com_github_miekg_dns//:go_default_library", "@io_k8s_api//core/v1:go_default_library", "@io_k8s_api//networking/v1:go_default_library", "@io_k8s_apimachinery//pkg/api/errors:go_default_library", diff --git a/pkg/issuer/acme/http/http.go b/pkg/issuer/acme/http/http.go index 2eee2bbbe..f3dd253ff 100644 --- a/pkg/issuer/acme/http/http.go +++ b/pkg/issuer/acme/http/http.go @@ -70,7 +70,7 @@ type Solver struct { requiredPasses int } -type reachabilityTest func(ctx context.Context, url *url.URL, key string) error +type reachabilityTest func(ctx context.Context, url *url.URL, key string, dnsServers []string) error // NewSolver returns a new ACME HTTP01 solver for the given *controller.Context. func NewSolver(ctx *controller.Context) (*Solver, error) { @@ -172,7 +172,7 @@ func (s *Solver) Check(ctx context.Context, issuer v1.GenericIssuer, ch *cmacme. log.V(logf.DebugLevel).Info("running self check multiple times to ensure challenge has propagated", "required_passes", s.requiredPasses) for i := 0; i < s.requiredPasses; i++ { - err := s.testReachability(ctx, url, ch.Spec.Key) + err := s.testReachability(ctx, url, ch.Spec.Key, s.HTTP01SolverNameservers) if err != nil { return err } @@ -211,7 +211,7 @@ func (s *Solver) buildChallengeUrl(ch *cmacme.Challenge) *url.URL { // testReachability will attempt to connect to the 'domain' with 'path' and // check if the returned body equals 'key' -func testReachability(ctx context.Context, url *url.URL, key string) error { +func testReachability(ctx context.Context, url *url.URL, key string, dnsServers []string) error { log := logf.FromContext(ctx) log.V(logf.DebugLevel).Info("performing HTTP01 reachability check") @@ -269,6 +269,28 @@ func testReachability(ctx context.Context, url *url.URL, key string) error { }, } + if len(dnsServers) != 0 { + transport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + // we need to increment a counter to iterate through the dns servers as the dialer will not + // return an error if the dns server is not responding. + counter := 0 + dialer := &net.Dialer{ + Timeout: 3 * time.Second, + Resolver: &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{ + Timeout: 3 * time.Second, + } + s := dnsServers[counter%len(dnsServers)] + counter++ + return d.DialContext(ctx, network, s) + }, + }, + } + return dialer.DialContext(ctx, network, addr) + } + } client := &http.Client{ Transport: transport, Timeout: time.Second * 10, diff --git a/pkg/issuer/acme/http/http_test.go b/pkg/issuer/acme/http/http_test.go index 4fef74252..4b18f6dea 100644 --- a/pkg/issuer/acme/http/http_test.go +++ b/pkg/issuer/acme/http/http_test.go @@ -19,18 +19,24 @@ package http import ( "context" "fmt" + "net" "net/url" + "strings" + "sync/atomic" "testing" + "github.com/miekg/dns" + cmacme "github.com/jetstack/cert-manager/pkg/apis/acme/v1" + "github.com/jetstack/cert-manager/pkg/controller" ) // countReachabilityTestCalls is a wrapper function that allows us to count the number // of calls to a reachabilityTest. func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest { - return func(ctx context.Context, url *url.URL, key string) error { + return func(ctx context.Context, url *url.URL, key string, dnsServers []string) error { *counter++ - return t(ctx, url, key) + return t(ctx, url, key, dnsServers) } } @@ -44,14 +50,14 @@ func TestCheck(t *testing.T) { tests := []testT{ { name: "should pass", - reachabilityTest: func(context.Context, *url.URL, string) error { + reachabilityTest: func(context.Context, *url.URL, string, []string) error { return nil }, expectedErr: false, }, { name: "should error", - reachabilityTest: func(context.Context, *url.URL, string) error { + reachabilityTest: func(context.Context, *url.URL, string, []string) error { return fmt.Errorf("failed") }, expectedErr: true, @@ -67,6 +73,7 @@ func TestCheck(t *testing.T) { test.challenge = &cmacme.Challenge{} } s := Solver{ + Context: &controller.Context{}, testReachability: countReachabilityTestCalls(&calls, test.reachabilityTest), requiredPasses: requiredCallsForPass, } @@ -87,3 +94,98 @@ func TestCheck(t *testing.T) { }) } } + +func TestReachabilityCustomDnsServers(t *testing.T) { + site := "https://cert-manager.io" + u, err := url.Parse(site) + if err != nil { + t.Fatalf("Failed to parse url %s: %v", site, err) + } + ips, err := net.LookupIP(u.Host) + if err != nil { + t.Fatalf("Failed to resolve %s: %v", u.Host, err) + } + + dnsServerCalled := int32(0) + + server := &dns.Server{Addr: "127.0.0.1:15353", Net: "udp"} + defer server.Shutdown() + + dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { + m := new(dns.Msg) + m.SetReply(r) + + if r.Opcode != dns.OpcodeQuery { + return + } + for _, q := range m.Question { + if q.Name != u.Host+"." { + continue + } + switch q.Qtype { + case dns.TypeA: + t.Logf("A Query for %s\n", q.Name) + atomic.StoreInt32(&dnsServerCalled, 1) + for _, ip := range ips { + if strings.Contains(ip.String(), ":") { + continue + } + rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } + case dns.TypeAAAA: + t.Logf("AAAA Query for %s\n", q.Name) + atomic.StoreInt32(&dnsServerCalled, 1) + for _, ip := range ips { + if !strings.Contains(ip.String(), ":") { + continue + } + rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip)) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } + } + } + if err := w.WriteMsg(m); err != nil { + t.Errorf("failed to write DNS response: %v", err) + } + }) + go server.ListenAndServe() + + key := "there is no key" + + tests := []struct { + name string + dnsServers []string + dnsServerCalled bool + }{ + { + name: "custom dns servers", + dnsServers: []string{"127.0.0.1:15353"}, + dnsServerCalled: true, + }, + { + name: "system dns servers", + dnsServerCalled: false, + }, + } + + for _, tt := range tests { + atomic.StoreInt32(&dnsServerCalled, 0) + err = testReachability(context.Background(), u, key, tt.dnsServers) + switch { + case err == nil: + t.Errorf("Expected error for testReachability, but got none") + case strings.Contains(err.Error(), key): + called := atomic.LoadInt32(&dnsServerCalled) == 1 + if called != tt.dnsServerCalled { + t.Errorf("Expected DNS server called: %v, but got %v", tt.dnsServerCalled, called) + } + default: + t.Errorf("Unexpected error: %v", err) + } + } +}