From d09c293b73e0c89f6e10e6c995952cfed8688ed4 Mon Sep 17 00:00:00 2001 From: Christopher Hlubek Date: Sun, 21 Oct 2018 22:02:26 +0200 Subject: [PATCH] Respect HTTP01Timeout, improve logging Signed-off-by: Christopher Hlubek --- pkg/issuer/acme/http/http.go | 56 +++++++++++++++++++++---------- pkg/issuer/acme/http/http_test.go | 16 ++++++--- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/pkg/issuer/acme/http/http.go b/pkg/issuer/acme/http/http.go index fc3377b36..374e95ed7 100644 --- a/pkg/issuer/acme/http/http.go +++ b/pkg/issuer/acme/http/http.go @@ -61,7 +61,16 @@ type Solver struct { requiredPasses int } -type reachabilityTest func(ctx context.Context, domain, path, key string) (bool, error) +type reachabilityTest func(ctx context.Context, url, key string) (bool, error) + +// absorbErr wraps an error to mark it as absorbable (log and handle as nil) +type absorbErr struct { + err error +} + +func (ae *absorbErr) Error() string { + return ae.err.Error() +} // NewSolver returns a new ACME HTTP01 solver for the given Issuer and client. // TODO: refactor this to have fewer args @@ -92,9 +101,15 @@ func (s *Solver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, ch func (s *Solver) Check(ch *v1alpha1.Challenge) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout) defer cancel() + + url := s.buildChallengeUrl(ch) + for i := 0; i < s.requiredPasses; i++ { - ok, err := s.testReachability(ctx, ch.Spec.DNSName, fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token), ch.Spec.Key) - if err != nil { + ok, err := s.testReachability(ctx, url, ch.Spec.Key) + if absorbedErr, wasAbsorbed := err.(*absorbErr); wasAbsorbed { + glog.Infof("could not reach '%s': %v", url, absorbedErr.err) + return false, nil + } else if err != nil { return false, err } if !ok { @@ -115,35 +130,42 @@ func (s *Solver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, ch return utilerrors.NewAggregate(errs) } -// testReachability will attempt to connect to the 'domain' with 'path' and -// check if the returned body equals 'key' -func testReachability(ctx context.Context, domain, path, key string) (bool, error) { +func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) string { url := &url.URL{} url.Scheme = "http" - url.Host = domain - url.Path = path + url.Host = ch.Spec.DNSName + url.Path = fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token) - response, err := http.Get(url.String()) + return url.String() +} + +// testReachability will attempt to connect to the 'domain' with 'path' and +// check if the returned body equals 'key' +func testReachability(ctx context.Context, url string, key string) (bool, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - // absorb http client errors - return false, nil + return false, fmt.Errorf("failed to build request: %v", err) + } + + req = req.WithContext(ctx) + + response, err := http.DefaultClient.Do(req) + if err != nil { + return false, &absorbErr{err: fmt.Errorf("failed to GET '%s': %v", url, err)} } if response.StatusCode != http.StatusOK { - // TODO: log this elsewhere - glog.Infof("wrong status code '%d'", response.StatusCode) - return false, nil + return false, &absorbErr{err: fmt.Errorf("wrong status code '%d', expected '%d'", response.StatusCode, http.StatusOK)} } defer response.Body.Close() presentedKey, err := ioutil.ReadAll(response.Body) if err != nil { - return false, err + return false, fmt.Errorf("failed to read response body: %v", err) } if string(presentedKey) != key { - glog.Infof("presented key (%s) did not match expected (%s)", presentedKey, key) - return false, nil + return false, &absorbErr{err: fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key)} } return true, nil diff --git a/pkg/issuer/acme/http/http_test.go b/pkg/issuer/acme/http/http_test.go index 4340109de..4156dc255 100644 --- a/pkg/issuer/acme/http/http_test.go +++ b/pkg/issuer/acme/http/http_test.go @@ -27,9 +27,9 @@ import ( // countReachabilityTestCalls is a wrapper function that allows us to count the number // of calls to a reachabilityTest. func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest { - return func(ctx context.Context, domain, path, key string) (bool, error) { + return func(ctx context.Context, url, key string) (bool, error) { *counter++ - return t(ctx, domain, path, key) + return t(ctx, url, key) } } @@ -44,20 +44,26 @@ func TestCheck(t *testing.T) { tests := []testT{ { name: "should pass", - reachabilityTest: func(context.Context, string, string, string) (bool, error) { + reachabilityTest: func(context.Context, string, string) (bool, error) { return true, nil }, expectedOk: true, }, { name: "should fail", - reachabilityTest: func(context.Context, string, string, string) (bool, error) { + reachabilityTest: func(context.Context, string, string) (bool, error) { return false, nil }, }, + { + name: "should fail with absorbed error", + reachabilityTest: func(context.Context, string, string) (bool, error) { + return false, &absorbErr{err: fmt.Errorf("failed")} + }, + }, { name: "should error", - reachabilityTest: func(context.Context, string, string, string) (bool, error) { + reachabilityTest: func(context.Context, string, string) (bool, error) { return false, fmt.Errorf("failed") }, expectedErr: true,