diff --git a/pkg/controller/acmechallenges/sync.go b/pkg/controller/acmechallenges/sync.go index 3b78a4e33..4cf857842 100644 --- a/pkg/controller/acmechallenges/sync.go +++ b/pkg/controller/acmechallenges/sync.go @@ -42,10 +42,8 @@ const ( type solver interface { // Present the challenge value with the given solver. Present(ctx context.Context, issuer cmapi.GenericIssuer, ch *cmapi.Challenge) error - // Check should return Error only if propagation check cannot be performed. - // It MUST return `false, nil` if can contact all relevant services and all is - // doing is waiting for propagation - Check(ctx context.Context, issuer cmapi.GenericIssuer, ch *cmapi.Challenge) (bool, error) + // Check returns an Error if the propagation check didn't succeed. + Check(ctx context.Context, issuer cmapi.GenericIssuer, ch *cmapi.Challenge) error // CleanUp will remove challenge records for a given solver. // This may involve deleting resources in the Kubernetes API Server, or // communicating with other external components (e.g. DNS providers). @@ -148,12 +146,10 @@ func (c *Controller) Sync(ctx context.Context, ch *cmapi.Challenge) (err error) c.Recorder.Eventf(ch, corev1.EventTypeNormal, "Presented", "Presented challenge using %s challenge mechanism", ch.Spec.Type) } - ok, err := solver.Check(ctx, genericIssuer, ch) + err = solver.Check(ctx, genericIssuer, ch) if err != nil { - return err - } - if !ok { - ch.Status.Reason = fmt.Sprintf("Waiting for %s challenge propagation", ch.Spec.Type) + glog.Infof("propagation check failed: %v", err) + ch.Status.Reason = fmt.Sprintf("Waiting for %s challenge propagation: %s", ch.Spec.Type, err) key, err := controllerpkg.KeyFunc(ch) // This is an unexpected edge case and should never occur diff --git a/pkg/controller/acmechallenges/sync_test.go b/pkg/controller/acmechallenges/sync_test.go index d0911d5b0..1237678a7 100644 --- a/pkg/controller/acmechallenges/sync_test.go +++ b/pkg/controller/acmechallenges/sync_test.go @@ -18,6 +18,7 @@ package acmechallenges import ( "context" + "fmt" "testing" "k8s.io/apimachinery/pkg/runtime" @@ -38,7 +39,7 @@ func (f *fakeSolver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, // Check should return Error only if propagation check cannot be performed. // It MUST return `false, nil` if can contact all relevant services and all is // doing is waiting for propagation -func (f *fakeSolver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) { +func (f *fakeSolver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { return f.fakeCheck(ctx, issuer, ch) } @@ -51,7 +52,7 @@ func (f *fakeSolver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, type fakeSolver struct { fakePresent func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error - fakeCheck func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) + fakeCheck func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error fakeCleanUp func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error } @@ -108,8 +109,8 @@ func TestSyncHappyPath(t *testing.T) { fakePresent: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { return nil }, - fakeCheck: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) { - return false, nil + fakeCheck: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { + return fmt.Errorf("some error") }, }, Builder: &testpkg.Builder{ @@ -127,7 +128,7 @@ func TestSyncHappyPath(t *testing.T) { gen.SetChallengeState(v1alpha1.Pending), gen.SetChallengePresented(true), gen.SetChallengeType("http-01"), - gen.SetChallengeReason("Waiting for http-01 challenge propagation"), + gen.SetChallengeReason("Waiting for http-01 challenge propagation: some error"), ))), }, }, @@ -146,8 +147,8 @@ func TestSyncHappyPath(t *testing.T) { gen.SetChallengePresented(true), ), HTTP01: &fakeSolver{ - fakeCheck: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) { - return true, nil + fakeCheck: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { + return nil }, fakeCleanUp: func(context.Context, v1alpha1.GenericIssuer, *v1alpha1.Challenge) error { return nil @@ -198,8 +199,8 @@ func TestSyncHappyPath(t *testing.T) { gen.SetChallengePresented(true), ), HTTP01: &fakeSolver{ - fakeCheck: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) { - return true, nil + fakeCheck: func(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { + return nil }, fakeCleanUp: func(context.Context, v1alpha1.GenericIssuer, *v1alpha1.Challenge) error { return nil diff --git a/pkg/issuer/acme/dns/dns.go b/pkg/issuer/acme/dns/dns.go index 9c4f307af..88401fca3 100644 --- a/pkg/issuer/acme/dns/dns.go +++ b/pkg/issuer/acme/dns/dns.go @@ -91,20 +91,11 @@ func (s *Solver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, ch } // Check verifies that the DNS records for the ACME challenge have propagated. -func (s *Solver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) { - providerName := ch.Spec.Config.DNS01.Provider - if providerName == "" { - return false, fmt.Errorf("dns01 challenge provider name must be set") - } +func (s *Solver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { - providerConfig, err := issuer.GetSpec().ACME.DNS01.Provider(providerName) + fqdn, value, ttl, err := util.DNS01Record(ch.Spec.DNSName, ch.Spec.Key, s.DNS01Nameservers, false) if err != nil { - return false, err - } - - fqdn, value, ttl, err := util.DNS01Record(ch.Spec.DNSName, ch.Spec.Key, s.DNS01Nameservers, followCNAME(providerConfig.CNAMEStrategy)) - if err != nil { - return false, err + return err } glog.Infof("Checking DNS propagation for %q using name servers: %v", ch.Spec.DNSName, s.Context.DNS01Nameservers) @@ -112,18 +103,17 @@ func (s *Solver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v ok, err := util.PreCheckDNS(fqdn, value, s.Context.DNS01Nameservers, s.Context.DNS01CheckAuthoritative) if err != nil { - return false, err + return err } if !ok { - glog.Infof("DNS record for %q not yet propagated", ch.Spec.DNSName) - return false, nil + return fmt.Errorf("DNS record for %q not yet propagated", ch.Spec.DNSName) } glog.Infof("Waiting DNS record TTL (%ds) to allow propagation of DNS record for domain %q", ttl, fqdn) time.Sleep(time.Second * time.Duration(ttl)) glog.Infof("ACME DNS01 validation record propagated for %q", fqdn) - return true, nil + return nil } // CleanUp removes DNS records which are no longer needed after diff --git a/pkg/issuer/acme/http/http.go b/pkg/issuer/acme/http/http.go index fcee34983..32bd62243 100644 --- a/pkg/issuer/acme/http/http.go +++ b/pkg/issuer/acme/http/http.go @@ -25,7 +25,6 @@ import ( "net/url" "time" - "github.com/golang/glog" utilerrors "k8s.io/apimachinery/pkg/util/errors" corev1listers "k8s.io/client-go/listers/core/v1" extv1beta1listers "k8s.io/client-go/listers/extensions/v1beta1" @@ -62,16 +61,7 @@ type Solver struct { requiredPasses int } -type reachabilityTest func(ctx context.Context, url, key string) (bool, error) - -// absorbErr wraps an error to mark it as absorbable (log and handle as nil) -type absorbErr struct { - err error -} - -func (ae *absorbErr) Error() string { - return ae.err.Error() -} +type reachabilityTest func(ctx context.Context, url *url.URL, key string) error // NewSolver returns a new ACME HTTP01 solver for the given Issuer and client. // TODO: refactor this to have fewer args @@ -99,26 +89,20 @@ func (s *Solver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, ch return utilerrors.NewAggregate([]error{podErr, svcErr, ingressErr}) } -func (s *Solver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) (bool, error) { - ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout) +func (s *Solver) Check(ctx context.Context, issuer v1alpha1.GenericIssuer, ch *v1alpha1.Challenge) error { + ctx, cancel := context.WithTimeout(ctx, HTTP01Timeout) defer cancel() url := s.buildChallengeUrl(ch) for i := 0; i < s.requiredPasses; i++ { - ok, err := s.testReachability(ctx, url, ch.Spec.Key) - if absorbedErr, wasAbsorbed := err.(*absorbErr); wasAbsorbed { - glog.Infof("could not reach '%s': %v", url, absorbedErr.err) - return false, nil - } else if err != nil { - return false, err - } - if !ok { - return false, nil + err := s.testReachability(ctx, url, ch.Spec.Key) + if err != nil { + return err } time.Sleep(time.Second * 2) } - return true, nil + return nil } // CleanUp will ensure the created service, ingress and pod are clean/deleted of any @@ -131,21 +115,21 @@ func (s *Solver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, ch return utilerrors.NewAggregate(errs) } -func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) string { +func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) *url.URL { url := &url.URL{} url.Scheme = "http" url.Host = ch.Spec.DNSName url.Path = fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token) - return url.String() + return url } // testReachability will attempt to connect to the 'domain' with 'path' and // check if the returned body equals 'key' -func testReachability(ctx context.Context, url string, key string) (bool, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return false, fmt.Errorf("failed to build request: %v", err) +func testReachability(ctx context.Context, url *url.URL, key string) error { + req := &http.Request{ + Method: http.MethodGet, + URL: url, } req = req.WithContext(ctx) @@ -170,22 +154,22 @@ func testReachability(ctx context.Context, url string, key string) (bool, error) response, err := client.Do(req) if err != nil { - return false, &absorbErr{err: fmt.Errorf("failed to GET '%s': %v", url, err)} + return fmt.Errorf("failed to GET '%s': %v", url, err) } if response.StatusCode != http.StatusOK { - return false, &absorbErr{err: fmt.Errorf("wrong status code '%d', expected '%d'", response.StatusCode, http.StatusOK)} + return fmt.Errorf("wrong status code '%d', expected '%d'", response.StatusCode, http.StatusOK) } defer response.Body.Close() presentedKey, err := ioutil.ReadAll(response.Body) if err != nil { - return false, fmt.Errorf("failed to read response body: %v", err) + return fmt.Errorf("failed to read response body: %v", err) } if string(presentedKey) != key { - return false, &absorbErr{err: fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key)} + return fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key) } - return true, nil + return nil } diff --git a/pkg/issuer/acme/http/http_test.go b/pkg/issuer/acme/http/http_test.go index 90c6439fc..44c8539b5 100644 --- a/pkg/issuer/acme/http/http_test.go +++ b/pkg/issuer/acme/http/http_test.go @@ -19,6 +19,7 @@ package http import ( "context" "fmt" + "net/url" "testing" "github.com/jetstack/cert-manager/pkg/apis/certmanager/v1alpha1" @@ -27,7 +28,7 @@ import ( // countReachabilityTestCalls is a wrapper function that allows us to count the number // of calls to a reachabilityTest. func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest { - return func(ctx context.Context, url, key string) (bool, error) { + return func(ctx context.Context, url *url.URL, key string) error { *counter++ return t(ctx, url, key) } @@ -39,32 +40,19 @@ func TestCheck(t *testing.T) { reachabilityTest reachabilityTest challenge *v1alpha1.Challenge expectedErr bool - expectedOk bool } tests := []testT{ { name: "should pass", - reachabilityTest: func(context.Context, string, string) (bool, error) { - return true, nil - }, - expectedOk: true, - }, - { - name: "should fail", - reachabilityTest: func(context.Context, string, string) (bool, error) { - return false, nil - }, - }, - { - name: "should fail with absorbed error", - reachabilityTest: func(context.Context, string, string) (bool, error) { - return false, &absorbErr{err: fmt.Errorf("failed")} + reachabilityTest: func(context.Context, *url.URL, string) error { + return nil }, + expectedErr: false, }, { name: "should error", - reachabilityTest: func(context.Context, string, string) (bool, error) { - return false, fmt.Errorf("failed") + reachabilityTest: func(context.Context, *url.URL, string) error { + return fmt.Errorf("failed") }, expectedErr: true, }, @@ -83,7 +71,7 @@ func TestCheck(t *testing.T) { requiredPasses: requiredCallsForPass, } - ok, err := s.Check(nil, nil, test.challenge) + err := s.Check(context.Background(), nil, test.challenge) if err != nil && !test.expectedErr { t.Errorf("Expected Check to return non-nil error, but got %v", err) return @@ -92,10 +80,7 @@ func TestCheck(t *testing.T) { t.Errorf("Expected error from Check, but got none") return } - if test.expectedOk != ok { - t.Errorf("Expected ok=%t but got ok=%t", test.expectedOk, ok) - } - if test.expectedOk && calls != requiredCallsForPass { + if !test.expectedErr && calls != requiredCallsForPass { t.Errorf("Expected Wait to verify reachability test passes %d times, but only checked %d", requiredCallsForPass, calls) return }