Merge pull request #4287 from linka-cloud/acme-http-challenge-cutomer-dns

Acme http challenge custom dns
This commit is contained in:
jetstack-bot 2022-01-11 11:24:03 +00:00 committed by GitHub
commit fa321b6a4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 148 additions and 9 deletions

View File

@ -66,7 +66,7 @@ const controllerAgentName = "cert-manager"
// This sets the informer's resync period to 10 hours
// following the controller-runtime defaults
//and following discussion: https://github.com/kubernetes-sigs/controller-runtime/pull/88#issuecomment-408500629
// and following discussion: https://github.com/kubernetes-sigs/controller-runtime/pull/88#issuecomment-408500629
const resyncPeriod = 10 * time.Hour
func Run(opts *options.ControllerOptions, stopCh <-chan struct{}) error {
@ -360,6 +360,7 @@ func buildControllerContext(ctx context.Context, opts *options.ControllerOptions
HTTP01SolverResourceRequestMemory: HTTP01SolverResourceRequestMemory,
HTTP01SolverResourceLimitsCPU: HTTP01SolverResourceLimitsCPU,
HTTP01SolverResourceLimitsMemory: HTTP01SolverResourceLimitsMemory,
HTTP01SolverNameservers: opts.ACMEHTTP01SolverNameservers,
DNS01CheckAuthoritative: !opts.DNS01RecursiveNameserversOnly,
DNS01Nameservers: nameservers,
AccountRegistry: acmeAccountRegistry,

View File

@ -79,6 +79,8 @@ type ControllerOptions struct {
ACMEHTTP01SolverResourceRequestMemory string
ACMEHTTP01SolverResourceLimitsCPU string
ACMEHTTP01SolverResourceLimitsMemory string
// Allows specifying a list of custom nameservers to perform HTTP01 checks on.
ACMEHTTP01SolverNameservers []string
ClusterIssuerAmbientCredentials bool
IssuerAmbientCredentials bool
@ -232,6 +234,7 @@ func NewControllerOptions() *ControllerOptions {
DefaultIssuerKind: defaultTLSACMEIssuerKind,
DefaultIssuerGroup: defaultTLSACMEIssuerGroup,
DefaultAutoCertificateAnnotations: defaultAutoCertificateAnnotations,
ACMEHTTP01SolverNameservers: []string{},
DNS01RecursiveNameservers: []string{},
DNS01RecursiveNameserversOnly: defaultDNS01RecursiveNameserversOnly,
EnableCertificateOwnerRef: defaultEnableCertificateOwnerRef,
@ -298,6 +301,11 @@ func (s *ControllerOptions) AddFlags(fs *pflag.FlagSet) {
fs.StringVar(&s.ACMEHTTP01SolverResourceLimitsMemory, "acme-http01-solver-resource-limits-memory", defaultACMEHTTP01SolverResourceLimitsMemory, ""+
"Defines the resource limits Memory size when spawning new ACME HTTP01 challenge solver pods.")
fs.StringSliceVar(&s.ACMEHTTP01SolverNameservers, "acme-http01-solver-nameservers",
[]string{}, "A list of comma separated dns server endpoints used for "+
"ACME HTTP01 check requests. This should be a list containing host and "+
"port, for example 8.8.8.8:53,8.8.4.4:53")
fs.BoolVar(&s.ClusterIssuerAmbientCredentials, "cluster-issuer-ambient-credentials", defaultClusterIssuerAmbientCredentials, ""+
"Whether a cluster-issuer may make use of ambient credentials for issuers. 'Ambient Credentials' are credentials drawn from the environment, metadata services, or local files which are not explicitly configured in the ClusterIssuer API object. "+
"When this flag is enabled, the following sources for credentials are also used: "+
@ -369,7 +377,7 @@ func (o *ControllerOptions) Validate() error {
return fmt.Errorf("invalid value for kube-api-burst: %v must be higher or equal to kube-api-qps: %v", o.KubernetesAPIQPS, o.KubernetesAPIQPS)
}
for _, server := range o.DNS01RecursiveNameservers {
for _, server := range append(o.DNS01RecursiveNameservers, o.ACMEHTTP01SolverNameservers...) {
// ensure all servers have a port number
_, _, err := net.SplitHostPort(server)
if err != nil {

View File

@ -123,6 +123,10 @@ type ACMEOptions struct {
// HTTP01SolverResourceLimitsMemory defines the ACME pod's resource limits Memory size
HTTP01SolverResourceLimitsMemory resource.Quantity
// HTTP01SolverNameservers is a list of nameservers to use when performing self-checks
// for ACME HTTP01 validations.
HTTP01SolverNameservers []string
// DNS01CheckAuthoritative is a flag for controlling if auth nss are used
// for checking propagation of an RR. This is the ideal scenario
DNS01CheckAuthoritative bool

View File

@ -48,8 +48,10 @@ go_test(
embed = [":go_default_library"],
deps = [
"//pkg/apis/acme/v1:go_default_library",
"//pkg/controller:go_default_library",
"//pkg/controller/test:go_default_library",
"//test/unit/gen:go_default_library",
"@com_github_miekg_dns//:go_default_library",
"@io_k8s_api//core/v1:go_default_library",
"@io_k8s_api//networking/v1:go_default_library",
"@io_k8s_apimachinery//pkg/api/errors:go_default_library",

View File

@ -70,7 +70,7 @@ type Solver struct {
requiredPasses int
}
type reachabilityTest func(ctx context.Context, url *url.URL, key string) error
type reachabilityTest func(ctx context.Context, url *url.URL, key string, dnsServers []string) error
// NewSolver returns a new ACME HTTP01 solver for the given *controller.Context.
func NewSolver(ctx *controller.Context) (*Solver, error) {
@ -172,7 +172,7 @@ func (s *Solver) Check(ctx context.Context, issuer v1.GenericIssuer, ch *cmacme.
log.V(logf.DebugLevel).Info("running self check multiple times to ensure challenge has propagated", "required_passes", s.requiredPasses)
for i := 0; i < s.requiredPasses; i++ {
err := s.testReachability(ctx, url, ch.Spec.Key)
err := s.testReachability(ctx, url, ch.Spec.Key, s.HTTP01SolverNameservers)
if err != nil {
return err
}
@ -211,7 +211,7 @@ func (s *Solver) buildChallengeUrl(ch *cmacme.Challenge) *url.URL {
// testReachability will attempt to connect to the 'domain' with 'path' and
// check if the returned body equals 'key'
func testReachability(ctx context.Context, url *url.URL, key string) error {
func testReachability(ctx context.Context, url *url.URL, key string, dnsServers []string) error {
log := logf.FromContext(ctx)
log.V(logf.DebugLevel).Info("performing HTTP01 reachability check")
@ -269,6 +269,28 @@ func testReachability(ctx context.Context, url *url.URL, key string) error {
},
}
if len(dnsServers) != 0 {
transport.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
// we need to increment a counter to iterate through the dns servers as the dialer will not
// return an error if the dns server is not responding.
counter := 0
dialer := &net.Dialer{
Timeout: 3 * time.Second,
Resolver: &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: 3 * time.Second,
}
s := dnsServers[counter%len(dnsServers)]
counter++
return d.DialContext(ctx, network, s)
},
},
}
return dialer.DialContext(ctx, network, addr)
}
}
client := &http.Client{
Transport: transport,
Timeout: time.Second * 10,

View File

@ -19,18 +19,24 @@ package http
import (
"context"
"fmt"
"net"
"net/url"
"strings"
"sync/atomic"
"testing"
"github.com/miekg/dns"
cmacme "github.com/jetstack/cert-manager/pkg/apis/acme/v1"
"github.com/jetstack/cert-manager/pkg/controller"
)
// countReachabilityTestCalls is a wrapper function that allows us to count the number
// of calls to a reachabilityTest.
func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest {
return func(ctx context.Context, url *url.URL, key string) error {
return func(ctx context.Context, url *url.URL, key string, dnsServers []string) error {
*counter++
return t(ctx, url, key)
return t(ctx, url, key, dnsServers)
}
}
@ -44,14 +50,14 @@ func TestCheck(t *testing.T) {
tests := []testT{
{
name: "should pass",
reachabilityTest: func(context.Context, *url.URL, string) error {
reachabilityTest: func(context.Context, *url.URL, string, []string) error {
return nil
},
expectedErr: false,
},
{
name: "should error",
reachabilityTest: func(context.Context, *url.URL, string) error {
reachabilityTest: func(context.Context, *url.URL, string, []string) error {
return fmt.Errorf("failed")
},
expectedErr: true,
@ -67,6 +73,7 @@ func TestCheck(t *testing.T) {
test.challenge = &cmacme.Challenge{}
}
s := Solver{
Context: &controller.Context{},
testReachability: countReachabilityTestCalls(&calls, test.reachabilityTest),
requiredPasses: requiredCallsForPass,
}
@ -87,3 +94,98 @@ func TestCheck(t *testing.T) {
})
}
}
func TestReachabilityCustomDnsServers(t *testing.T) {
site := "https://cert-manager.io"
u, err := url.Parse(site)
if err != nil {
t.Fatalf("Failed to parse url %s: %v", site, err)
}
ips, err := net.LookupIP(u.Host)
if err != nil {
t.Fatalf("Failed to resolve %s: %v", u.Host, err)
}
dnsServerCalled := int32(0)
server := &dns.Server{Addr: "127.0.0.1:15353", Net: "udp"}
defer server.Shutdown()
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
if r.Opcode != dns.OpcodeQuery {
return
}
for _, q := range m.Question {
if q.Name != u.Host+"." {
continue
}
switch q.Qtype {
case dns.TypeA:
t.Logf("A Query for %s\n", q.Name)
atomic.StoreInt32(&dnsServerCalled, 1)
for _, ip := range ips {
if strings.Contains(ip.String(), ":") {
continue
}
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeAAAA:
t.Logf("AAAA Query for %s\n", q.Name)
atomic.StoreInt32(&dnsServerCalled, 1)
for _, ip := range ips {
if !strings.Contains(ip.String(), ":") {
continue
}
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
if err := w.WriteMsg(m); err != nil {
t.Errorf("failed to write DNS response: %v", err)
}
})
go server.ListenAndServe()
key := "there is no key"
tests := []struct {
name string
dnsServers []string
dnsServerCalled bool
}{
{
name: "custom dns servers",
dnsServers: []string{"127.0.0.1:15353"},
dnsServerCalled: true,
},
{
name: "system dns servers",
dnsServerCalled: false,
},
}
for _, tt := range tests {
atomic.StoreInt32(&dnsServerCalled, 0)
err = testReachability(context.Background(), u, key, tt.dnsServers)
switch {
case err == nil:
t.Errorf("Expected error for testReachability, but got none")
case strings.Contains(err.Error(), key):
called := atomic.LoadInt32(&dnsServerCalled) == 1
if called != tt.dnsServerCalled {
t.Errorf("Expected DNS server called: %v, but got %v", tt.dnsServerCalled, called)
}
default:
t.Errorf("Unexpected error: %v", err)
}
}
}