Merge pull request #976 from hlubek/fix-http01-timeout-and-improve-logging

Respect HTTP01Timeout, improve logging
This commit is contained in:
jetstack-bot 2018-10-23 19:22:01 +01:00 committed by GitHub
commit 222c997acc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 50 additions and 22 deletions

View File

@ -61,7 +61,16 @@ type Solver struct {
requiredPasses int
}
type reachabilityTest func(ctx context.Context, domain, path, key string) (bool, error)
type reachabilityTest func(ctx context.Context, url, key string) (bool, error)
// absorbErr wraps an error to mark it as absorbable (log and handle as nil)
type absorbErr struct {
err error
}
func (ae *absorbErr) Error() string {
return ae.err.Error()
}
// NewSolver returns a new ACME HTTP01 solver for the given Issuer and client.
// TODO: refactor this to have fewer args
@ -92,9 +101,15 @@ func (s *Solver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
func (s *Solver) Check(ch *v1alpha1.Challenge) (bool, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout)
defer cancel()
url := s.buildChallengeUrl(ch)
for i := 0; i < s.requiredPasses; i++ {
ok, err := s.testReachability(ctx, ch.Spec.DNSName, fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token), ch.Spec.Key)
if err != nil {
ok, err := s.testReachability(ctx, url, ch.Spec.Key)
if absorbedErr, wasAbsorbed := err.(*absorbErr); wasAbsorbed {
glog.Infof("could not reach '%s': %v", url, absorbedErr.err)
return false, nil
} else if err != nil {
return false, err
}
if !ok {
@ -115,35 +130,42 @@ func (s *Solver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
return utilerrors.NewAggregate(errs)
}
// testReachability will attempt to connect to the 'domain' with 'path' and
// check if the returned body equals 'key'
func testReachability(ctx context.Context, domain, path, key string) (bool, error) {
func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) string {
url := &url.URL{}
url.Scheme = "http"
url.Host = domain
url.Path = path
url.Host = ch.Spec.DNSName
url.Path = fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token)
response, err := http.Get(url.String())
return url.String()
}
// testReachability will attempt to connect to the 'domain' with 'path' and
// check if the returned body equals 'key'
func testReachability(ctx context.Context, url string, key string) (bool, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
// absorb http client errors
return false, nil
return false, fmt.Errorf("failed to build request: %v", err)
}
req = req.WithContext(ctx)
response, err := http.DefaultClient.Do(req)
if err != nil {
return false, &absorbErr{err: fmt.Errorf("failed to GET '%s': %v", url, err)}
}
if response.StatusCode != http.StatusOK {
// TODO: log this elsewhere
glog.Infof("wrong status code '%d'", response.StatusCode)
return false, nil
return false, &absorbErr{err: fmt.Errorf("wrong status code '%d', expected '%d'", response.StatusCode, http.StatusOK)}
}
defer response.Body.Close()
presentedKey, err := ioutil.ReadAll(response.Body)
if err != nil {
return false, err
return false, fmt.Errorf("failed to read response body: %v", err)
}
if string(presentedKey) != key {
glog.Infof("presented key (%s) did not match expected (%s)", presentedKey, key)
return false, nil
return false, &absorbErr{err: fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key)}
}
return true, nil

View File

@ -27,9 +27,9 @@ import (
// countReachabilityTestCalls is a wrapper function that allows us to count the number
// of calls to a reachabilityTest.
func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest {
return func(ctx context.Context, domain, path, key string) (bool, error) {
return func(ctx context.Context, url, key string) (bool, error) {
*counter++
return t(ctx, domain, path, key)
return t(ctx, url, key)
}
}
@ -44,20 +44,26 @@ func TestCheck(t *testing.T) {
tests := []testT{
{
name: "should pass",
reachabilityTest: func(context.Context, string, string, string) (bool, error) {
reachabilityTest: func(context.Context, string, string) (bool, error) {
return true, nil
},
expectedOk: true,
},
{
name: "should fail",
reachabilityTest: func(context.Context, string, string, string) (bool, error) {
reachabilityTest: func(context.Context, string, string) (bool, error) {
return false, nil
},
},
{
name: "should fail with absorbed error",
reachabilityTest: func(context.Context, string, string) (bool, error) {
return false, &absorbErr{err: fmt.Errorf("failed")}
},
},
{
name: "should error",
reachabilityTest: func(context.Context, string, string, string) (bool, error) {
reachabilityTest: func(context.Context, string, string) (bool, error) {
return false, fmt.Errorf("failed")
},
expectedErr: true,