Merge pull request #976 from hlubek/fix-http01-timeout-and-improve-logging
Respect HTTP01Timeout, improve logging
This commit is contained in:
commit
222c997acc
@ -61,7 +61,16 @@ type Solver struct {
|
||||
requiredPasses int
|
||||
}
|
||||
|
||||
type reachabilityTest func(ctx context.Context, domain, path, key string) (bool, error)
|
||||
type reachabilityTest func(ctx context.Context, url, key string) (bool, error)
|
||||
|
||||
// absorbErr wraps an error to mark it as absorbable (log and handle as nil)
|
||||
type absorbErr struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (ae *absorbErr) Error() string {
|
||||
return ae.err.Error()
|
||||
}
|
||||
|
||||
// NewSolver returns a new ACME HTTP01 solver for the given Issuer and client.
|
||||
// TODO: refactor this to have fewer args
|
||||
@ -92,9 +101,15 @@ func (s *Solver) Present(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
|
||||
func (s *Solver) Check(ch *v1alpha1.Challenge) (bool, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout)
|
||||
defer cancel()
|
||||
|
||||
url := s.buildChallengeUrl(ch)
|
||||
|
||||
for i := 0; i < s.requiredPasses; i++ {
|
||||
ok, err := s.testReachability(ctx, ch.Spec.DNSName, fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token), ch.Spec.Key)
|
||||
if err != nil {
|
||||
ok, err := s.testReachability(ctx, url, ch.Spec.Key)
|
||||
if absorbedErr, wasAbsorbed := err.(*absorbErr); wasAbsorbed {
|
||||
glog.Infof("could not reach '%s': %v", url, absorbedErr.err)
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !ok {
|
||||
@ -115,35 +130,42 @@ func (s *Solver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
|
||||
return utilerrors.NewAggregate(errs)
|
||||
}
|
||||
|
||||
// testReachability will attempt to connect to the 'domain' with 'path' and
|
||||
// check if the returned body equals 'key'
|
||||
func testReachability(ctx context.Context, domain, path, key string) (bool, error) {
|
||||
func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) string {
|
||||
url := &url.URL{}
|
||||
url.Scheme = "http"
|
||||
url.Host = domain
|
||||
url.Path = path
|
||||
url.Host = ch.Spec.DNSName
|
||||
url.Path = fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token)
|
||||
|
||||
response, err := http.Get(url.String())
|
||||
return url.String()
|
||||
}
|
||||
|
||||
// testReachability will attempt to connect to the 'domain' with 'path' and
|
||||
// check if the returned body equals 'key'
|
||||
func testReachability(ctx context.Context, url string, key string) (bool, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
// absorb http client errors
|
||||
return false, nil
|
||||
return false, fmt.Errorf("failed to build request: %v", err)
|
||||
}
|
||||
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, &absorbErr{err: fmt.Errorf("failed to GET '%s': %v", url, err)}
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
// TODO: log this elsewhere
|
||||
glog.Infof("wrong status code '%d'", response.StatusCode)
|
||||
return false, nil
|
||||
return false, &absorbErr{err: fmt.Errorf("wrong status code '%d', expected '%d'", response.StatusCode, http.StatusOK)}
|
||||
}
|
||||
|
||||
defer response.Body.Close()
|
||||
presentedKey, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
if string(presentedKey) != key {
|
||||
glog.Infof("presented key (%s) did not match expected (%s)", presentedKey, key)
|
||||
return false, nil
|
||||
return false, &absorbErr{err: fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key)}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
|
||||
@ -27,9 +27,9 @@ import (
|
||||
// countReachabilityTestCalls is a wrapper function that allows us to count the number
|
||||
// of calls to a reachabilityTest.
|
||||
func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest {
|
||||
return func(ctx context.Context, domain, path, key string) (bool, error) {
|
||||
return func(ctx context.Context, url, key string) (bool, error) {
|
||||
*counter++
|
||||
return t(ctx, domain, path, key)
|
||||
return t(ctx, url, key)
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,20 +44,26 @@ func TestCheck(t *testing.T) {
|
||||
tests := []testT{
|
||||
{
|
||||
name: "should pass",
|
||||
reachabilityTest: func(context.Context, string, string, string) (bool, error) {
|
||||
reachabilityTest: func(context.Context, string, string) (bool, error) {
|
||||
return true, nil
|
||||
},
|
||||
expectedOk: true,
|
||||
},
|
||||
{
|
||||
name: "should fail",
|
||||
reachabilityTest: func(context.Context, string, string, string) (bool, error) {
|
||||
reachabilityTest: func(context.Context, string, string) (bool, error) {
|
||||
return false, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should fail with absorbed error",
|
||||
reachabilityTest: func(context.Context, string, string) (bool, error) {
|
||||
return false, &absorbErr{err: fmt.Errorf("failed")}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should error",
|
||||
reachabilityTest: func(context.Context, string, string, string) (bool, error) {
|
||||
reachabilityTest: func(context.Context, string, string) (bool, error) {
|
||||
return false, fmt.Errorf("failed")
|
||||
},
|
||||
expectedErr: true,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user