don't roundtrip url into strings and back

Signed-off-by: Daniel Morsing <dmo@jetstack.io>
This commit is contained in:
Daniel Morsing 2019-01-17 12:46:01 +00:00
parent 3767a6df23
commit 62923a9ba8
2 changed files with 13 additions and 12 deletions

View File

@ -61,7 +61,7 @@ type Solver struct {
requiredPasses int
}
type reachabilityTest func(ctx context.Context, url, key string) (bool, error)
type reachabilityTest func(ctx context.Context, url *url.URL, key string) (bool, error)
// absorbErr wraps an error to mark it as absorbable (log and handle as nil)
type absorbErr struct {
@ -130,21 +130,21 @@ func (s *Solver) CleanUp(ctx context.Context, issuer v1alpha1.GenericIssuer, ch
return utilerrors.NewAggregate(errs)
}
func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) string {
func (s *Solver) buildChallengeUrl(ch *v1alpha1.Challenge) *url.URL {
url := &url.URL{}
url.Scheme = "http"
url.Host = ch.Spec.DNSName
url.Path = fmt.Sprintf("%s/%s", solver.HTTPChallengePath, ch.Spec.Token)
return url.String()
return url
}
// testReachability will attempt to connect to the 'domain' with 'path' and
// check if the returned body equals 'key'
func testReachability(ctx context.Context, url string, key string) (bool, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return false, fmt.Errorf("failed to build request: %v", err)
func testReachability(ctx context.Context, url *url.URL, key string) (bool, error) {
req := &http.Request{
Method: http.MethodGet,
URL: url,
}
req = req.WithContext(ctx)

View File

@ -19,6 +19,7 @@ package http
import (
"context"
"fmt"
"net/url"
"testing"
"github.com/jetstack/cert-manager/pkg/apis/certmanager/v1alpha1"
@ -27,7 +28,7 @@ import (
// countReachabilityTestCalls is a wrapper function that allows us to count the number
// of calls to a reachabilityTest.
func countReachabilityTestCalls(counter *int, t reachabilityTest) reachabilityTest {
return func(ctx context.Context, url, key string) (bool, error) {
return func(ctx context.Context, url *url.URL, key string) (bool, error) {
*counter++
return t(ctx, url, key)
}
@ -44,26 +45,26 @@ func TestCheck(t *testing.T) {
tests := []testT{
{
name: "should pass",
reachabilityTest: func(context.Context, string, string) (bool, error) {
reachabilityTest: func(context.Context, *url.URL, string) (bool, error) {
return true, nil
},
expectedOk: true,
},
{
name: "should fail",
reachabilityTest: func(context.Context, string, string) (bool, error) {
reachabilityTest: func(context.Context, *url.URL, string) (bool, error) {
return false, nil
},
},
{
name: "should fail with absorbed error",
reachabilityTest: func(context.Context, string, string) (bool, error) {
reachabilityTest: func(context.Context, *url.URL, string) (bool, error) {
return false, &absorbErr{err: fmt.Errorf("failed")}
},
},
{
name: "should error",
reachabilityTest: func(context.Context, string, string) (bool, error) {
reachabilityTest: func(context.Context, *url.URL, string) (bool, error) {
return false, fmt.Errorf("failed")
},
expectedErr: true,