diff --git a/pkg/issuer/acme/http/http.go b/pkg/issuer/acme/http/http.go index bec93be62..e2ca77a7a 100644 --- a/pkg/issuer/acme/http/http.go +++ b/pkg/issuer/acme/http/http.go @@ -2,7 +2,6 @@ package http import ( "context" - "errors" "fmt" "io/ioutil" "net/http" @@ -11,6 +10,7 @@ import ( "sync" "time" + "github.com/golang/glog" corev1 "k8s.io/api/core/v1" extv1beta1 "k8s.io/api/extensions/v1beta1" k8sErrors "k8s.io/apimachinery/pkg/api/errors" @@ -71,7 +71,7 @@ type Solver struct { lock sync.Mutex } -type reachabilityTest func(ctx context.Context, domain, path, key string) error +type reachabilityTest func(ctx context.Context, domain, path, key string) (bool, error) // NewSolver returns a new ACME HTTP01 solver for the given Issuer and client. func NewSolver(issuer v1alpha1.GenericIssuer, client kubernetes.Interface, secretLister corev1listers.SecretLister, solverImage string) *Solver { @@ -391,10 +391,13 @@ func (s *Solver) Check(domain, token, key string) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), HTTP01Timeout) defer cancel() for i := 0; i < s.requiredPasses; i++ { - err := s.testReachability(ctx, domain, fmt.Sprintf("%s/%s", solver.HTTPChallengePath, token), key) + ok, err := s.testReachability(ctx, domain, fmt.Sprintf("%s/%s", solver.HTTPChallengePath, token), key) if err != nil { return false, err } + if !ok { + return false, nil + } time.Sleep(time.Second * 2) } return true, nil @@ -402,7 +405,7 @@ func (s *Solver) Check(domain, token, key string) (bool, error) { // testReachability will attempt to connect to the 'domain' with 'path' and // check if the returned body equals 'key' -func testReachability(ctx context.Context, domain, path, key string) error { +func testReachability(ctx context.Context, domain, path, key string) (bool, error) { url := &url.URL{} url.Scheme = "http" url.Host = domain @@ -410,24 +413,27 @@ func testReachability(ctx context.Context, domain, path, key string) error { response, err := http.Get(url.String()) if err != nil { - return err + return false, err } if response.StatusCode != http.StatusOK { - return fmt.Errorf("wrong status code '%d'", response.StatusCode) + // TODO: log this elsewhere + glog.Infof("wrong status code '%d'", response.StatusCode) + return false, nil } defer response.Body.Close() presentedKey, err := ioutil.ReadAll(response.Body) if err != nil { - return errors.New("unable to read body") + return false, err } if string(presentedKey) != key { - return fmt.Errorf("presented key (%s) did not match expected (%s)", presentedKey, key) + glog.Infof("presented key (%s) did not match expected (%s)", presentedKey, key) + return false, nil } - return nil + return false, nil } // CleanUp will ensure the created service and ingress are clean/deleted of any diff --git a/pkg/issuer/acme/http/http_test.go b/pkg/issuer/acme/http/http_test.go index ac7cd6305..c94a87ed4 100644 --- a/pkg/issuer/acme/http/http_test.go +++ b/pkg/issuer/acme/http/http_test.go @@ -26,9 +26,8 @@ func TestWait(t *testing.T) { type testT struct { name string reachabilityTest func(ctx context.Context, domain, path, key string) error - ctx context.Context domain, token, key string - expectedErr error + expectedErr bool } tests := []testT{ { @@ -36,15 +35,13 @@ func TestWait(t *testing.T) { reachabilityTest: func(context.Context, string, string, string) error { return nil }, - ctx: contextWithTimeout(time.Second * 30), }, { - name: "should timeout", + name: "should fail", reachabilityTest: func(context.Context, string, string, string) error { return fmt.Errorf("failed") }, - expectedErr: fmt.Errorf("context deadline exceeded"), - ctx: contextWithTimeout(time.Second * 30), + expectedErr: true, }, } @@ -58,22 +55,16 @@ func TestWait(t *testing.T) { requiredPasses: requiredCallsForPass, } - err := s.Wait(test.ctx, nil, test.domain, test.token, test.key) - if err != nil && test.expectedErr == nil { - t.Errorf("Expected Wait to return non-nil error, but got %v", err) + err := s.Check(test.domain, test.token, test.key) + if err != nil && !test.expectedErr { + t.Errorf("Expected Check to return non-nil error, but got %v", err) return } - if err != nil && test.expectedErr != nil { - if err.Error() != test.expectedErr.Error() { - t.Errorf("Expected error %v from Wait, but got: %v", test.expectedErr, err) - return - } - } - if err == nil && test.expectedErr != nil { - t.Errorf("Expected error %v from Wait, but got none", test.expectedErr) + if err == nil && test.expectedErr { + t.Errorf("Expected error from Check, but got none") return } - if test.expectedErr == nil && calls != requiredCallsForPass { + if test.expectedErr == false && calls != requiredCallsForPass { t.Errorf("Expected Wait to verify reachability test passes %d times, but only checked %d", requiredCallsForPass, calls) return }