Respect HTTP01Timeout, improve logging

Signed-off-by: Christopher Hlubek <hlubek@networkteam.com>
This commit is contained in:
Christopher Hlubek 2018-10-21 22:02:26 +02:00
parent 3347fcc613
commit d09c293b73
2 changed files with 50 additions and 22 deletions

View File

@ -61,7 +61,16 @@ type Solver struct {
requiredPasses int requiredPasses int
} }
type reachabilityTest func(ctx context.Context, domain, path, key string) (bool, error) type reachabilityTest func(ctx context.Context, url, key string) (bool, error)
// absorbErr wraps an error to mark it as absorbable (log and handle as nil)
type absorbErr struct {
err error
}
func (ae *absorbErr) Error() string {
return ae.err.Error()
}
// NewSolver returns a new ACME HTTP01 solver for the given Issuer and client. // NewSolver returns a new ACME HTTP01 solver for the given Issuer and client.
// TODO: refactor this to have fewer args // TODO: refactor this to have fewer args
@ -92,9 +101,15 @@ func (s *Solver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
func (s *Solver) Check(ch *v1alpha1.Challenge) (bool, error) { func (s *Solver) Check(ch *v1alpha1.Challenge) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout) ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout)
defer cancel() defer cancel()
url := s.buildChallengeUrl(ch)
for i := 0; i < s.requiredPasses; i++ { for i := 0; i < s.requiredPasses; i++ {
ok, err := s.testReachability(ctx, ch.Spec.DNSName, fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token), ch.Spec.Key) ok, err := s.testReachability(ctx, url, ch.Spec.Key)
if err != nil { if absorbedErr, wasAbsorbed := err.(*absorbErr); wasAbsorbed {
glog.Infof("could not reach '%s': %v", url, absorbedErr.err)
return false, nil
} else if err != nil {
return false, err return false, err
} }
if !ok { if !ok {
@ -115,35 +130,42 @@ func (s *Solver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
return utilerrors.NewAggregate(errs) return utilerrors.NewAggregate(errs)
} }
// testReachability will attempt to connect to the 'domain' with 'path' and func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) string {
// check if the returned body equals 'key'
func testReachability(ctx context.Context, domain, path, key string) (bool, error) {
url := &url.URL{} url := &url.URL{}
url.Scheme = "http" url.Scheme = "http"
url.Host = domain url.Host = ch.Spec.DNSName
url.Path = path url.Path = fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token)
response, err := http.Get(url.String()) return url.String()
}
// testReachability will attempt to connect to the 'domain' with 'path' and
// check if the returned body equals 'key'
func testReachability(ctx context.Context, url string, key string) (bool, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil { if err != nil {
// absorb http client errors return false, fmt.Errorf("failed to build request: %v", err)
return false, nil }
req = req.WithContext(ctx)
response, err := http.DefaultClient.Do(req)
if err != nil {
return false, &absorbErr{err: fmt.Errorf("failed to GET '%s': %v", url, err)}
} }
if response.StatusCode != http.StatusOK { if response.StatusCode != http.StatusOK {
// TODO: log this elsewhere return false, &absorbErr{err: fmt.Errorf("wrong status code '%d', expected '%d'", response.StatusCode, http.StatusOK)}
glog.Infof("wrong status code '%d'", response.StatusCode)
return false, nil
} }
defer response.Body.Close() defer response.Body.Close()
presentedKey, err := ioutil.ReadAll(response.Body) presentedKey, err := ioutil.ReadAll(response.Body)
if err != nil { if err != nil {
return false, err return false, fmt.Errorf("failed to read response body: %v", err)
} }
if string(presentedKey) != key { if string(presentedKey) != key {
glog.Infof("presented key (%s) did not match expected (%s)", presentedKey, key) return false, &absorbErr{err: fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key)}
return false, nil
} }
return true, nil return true, nil

View File

@ -27,9 +27,9 @@ import (
// countReachabilityTestCalls is a wrapper function that allows us to count the number // countReachabilityTestCalls is a wrapper function that allows us to count the number
// of calls to a reachabilityTest. // of calls to a reachabilityTest.
func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest { func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest {
return func(ctx context.Context, domain, path, key string) (bool, error) { return func(ctx context.Context, url, key string) (bool, error) {
*counter++ *counter++
return t(ctx, domain, path, key) return t(ctx, url, key)
} }
} }
@ -44,20 +44,26 @@ func TestCheck(t *testing.T) {
tests := []testT{ tests := []testT{
{ {
name: "should pass", name: "should pass",
reachabilityTest: func(context.Context, string, string, string) (bool, error) { reachabilityTest: func(context.Context, string, string) (bool, error) {
return true, nil return true, nil
}, },
expectedOk: true, expectedOk: true,
}, },
{ {
name: "should fail", name: "should fail",
reachabilityTest: func(context.Context, string, string, string) (bool, error) { reachabilityTest: func(context.Context, string, string) (bool, error) {
return false, nil return false, nil
}, },
}, },
{
name: "should fail with absorbed error",
reachabilityTest: func(context.Context, string, string) (bool, error) {
return false, &absorbErr{err: fmt.Errorf("failed")}
},
},
{ {
name: "should error", name: "should error",
reachabilityTest: func(context.Context, string, string, string) (bool, error) { reachabilityTest: func(context.Context, string, string) (bool, error) {
return false, fmt.Errorf("failed") return false, fmt.Errorf("failed")
}, },
expectedErr: true, expectedErr: true,