415 lines
15 KiB
Go
415 lines
15 KiB
Go
// +skip_license_check
|
|
|
|
/*
|
|
This file contains portions of code directly taken from the 'xenolf/lego' project.
|
|
A copy of the license for this code can be found in the file named LICENSE in
|
|
this directory.
|
|
*/
|
|
|
|
package azuredns
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
|
dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"k8s.io/apimachinery/pkg/util/rand"
|
|
|
|
v1 "github.com/cert-manager/cert-manager/pkg/apis/acme/v1"
|
|
"github.com/cert-manager/cert-manager/pkg/issuer/acme/dns/util"
|
|
)
|
|
|
|
// Configuration for the live Azure DNS tests. All values are populated from
// environment variables; azureLiveTest decides whether the live tests run or
// are skipped.
var (
	// azureLiveTest reports whether enough configuration was supplied to run
	// the tests that talk to a real Azure DNS zone.
	azureLiveTest bool

	// Service principal credentials plus the subscription and tenant they
	// belong to.
	azureClientID       string
	azureClientSecret   string
	azuresubscriptionID string
	azureTenantID       string

	// Where the hosted zone lives, and the domain the live tests operate on.
	azureResourceGroupName string
	azureHostedZoneName    string
	azureDomain            string
)
|
|
|
|
func init() {
|
|
azureClientID = os.Getenv("AZURE_CLIENT_ID")
|
|
azureClientSecret = os.Getenv("AZURE_CLIENT_SECRET")
|
|
azuresubscriptionID = os.Getenv("AZURE_SUBSCRIPTION_ID")
|
|
azureTenantID = os.Getenv("AZURE_TENANT_ID")
|
|
azureResourceGroupName = os.Getenv("AZURE_RESOURCE_GROUP")
|
|
azureHostedZoneName = os.Getenv("AZURE_ZONE_NAME")
|
|
azureDomain = os.Getenv("AZURE_DOMAIN")
|
|
if len(azureClientID) > 0 && len(azureClientSecret) > 0 && len(azureDomain) > 0 {
|
|
azureLiveTest = true
|
|
}
|
|
}
|
|
|
|
func TestLiveAzureDnsPresent(t *testing.T) {
|
|
if !azureLiveTest {
|
|
t.Skip("skipping live test")
|
|
}
|
|
provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.NoError(t, err)
|
|
|
|
err = provider.Present(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "123d==")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestLiveAzureDnsPresentMultiple(t *testing.T) {
|
|
if !azureLiveTest {
|
|
t.Skip("skipping live test")
|
|
}
|
|
provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.NoError(t, err)
|
|
|
|
err = provider.Present(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "123d==")
|
|
assert.NoError(t, err)
|
|
err = provider.Present(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "1123d==")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestLiveAzureDnsCleanUp(t *testing.T) {
|
|
if !azureLiveTest {
|
|
t.Skip("skipping live test")
|
|
}
|
|
|
|
time.Sleep(time.Second * 5)
|
|
|
|
provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.NoError(t, err)
|
|
|
|
err = provider.CleanUp(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "123d==")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestLiveAzureDnsCleanUpMultiple(t *testing.T) {
|
|
if !azureLiveTest {
|
|
t.Skip("skipping live test")
|
|
}
|
|
|
|
time.Sleep(time.Second * 10)
|
|
|
|
provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.NoError(t, err)
|
|
|
|
err = provider.CleanUp(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "123d==")
|
|
assert.NoError(t, err)
|
|
err = provider.CleanUp(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "1123d==")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestInvalidAzureDns(t *testing.T) {
|
|
validEnv := []string{"", "AzurePublicCloud", "AzureChinaCloud", "AzureUSGovernmentCloud"}
|
|
for _, env := range validEnv {
|
|
_, err := NewDNSProviderCredentials(env, "cid", "secret", "", "tenid", "", "", util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
// Invalid environment
|
|
_, err := NewDNSProviderCredentials("invalid env", "cid", "secret", "", "tenid", "", "", util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.Error(t, err)
|
|
|
|
// Invalid tenantID
|
|
_, err = NewDNSProviderCredentials("", "cid", "secret", "", "invalid env value", "", "", util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
// populateFederatedToken replaces the contents of filename with content. The
// file is created if needed and truncated first, so no bytes from a previous,
// longer token are left behind. Any failure aborts the calling test.
//
// os.WriteFile always closes the file, including when the write itself fails.
// The previous Create/WriteString/Close sequence leaked the handle on a write
// error, because FailNow exits before Close is reached.
func populateFederatedToken(t *testing.T, filename string, content string) {
	t.Helper()

	// 0o600: the file holds a (fake) credential; the mode only applies when
	// the file does not exist yet.
	if err := os.WriteFile(filename, []byte(content), 0o600); err != nil {
		t.Fatalf("writing federated token to %q: %v", filename, err)
	}
}
|
|
|
|
func TestGetAuthorizationFederatedSPT(t *testing.T) {
|
|
// Create a file that will be used to store a federated token
|
|
f, err := os.CreateTemp("", "")
|
|
if err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
defer os.Remove(f.Name())
|
|
|
|
// Close the file to simplify logic within populateFederatedToken helper
|
|
if err := f.Close(); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
// The initial federated token is never used, so we don't care about the value yet
|
|
// Though, it's a requirement from adal to have a non-empty value set
|
|
populateFederatedToken(t, f.Name(), "random-jwt")
|
|
|
|
// Prepare environment variables adal will rely on. Skip changes for some envs if they are already defined (=live environment)
|
|
// Envs themselves are described here: https://azure.github.io/azure-workload-identity/docs/installation/mutating-admission-webhook.html
|
|
if os.Getenv("AZURE_TENANT_ID") == "" {
|
|
// TODO(wallrj): This is a hack. It is a quick way to `DisableInstanceDiscovery` during tests,
|
|
// to avoid the client attempting to connect to https://login.microsoftonline.com/common/discovery/instance.
|
|
// It works because there is a special case in azure-sdk-for-go which
|
|
// disables the instance discovery when the tenant ID is `adfs`. See:
|
|
// https://github.com/Azure/azure-sdk-for-go/blob/7288bda422654bde520a09034dd755b8f2dd4168/sdk/azidentity/public_client.go#L237-L239
|
|
// https://learn.microsoft.com/en-us/windows-server/identity/ad-fs/ad-fs-overview
|
|
//
|
|
// Find a better way to test this code.
|
|
t.Setenv("AZURE_TENANT_ID", "adfs")
|
|
}
|
|
|
|
if os.Getenv("AZURE_CLIENT_ID") == "" {
|
|
t.Setenv("AZURE_CLIENT_ID", "fakeClientID")
|
|
}
|
|
|
|
t.Setenv("AZURE_FEDERATED_TOKEN_FILE", f.Name())
|
|
|
|
t.Run("token refresh", func(t *testing.T) {
|
|
// Basically, we want one token to be exchanged for the other (key and value respectively)
|
|
tokens := map[string]string{
|
|
"initialFederatedToken": "initialAccessToken",
|
|
"refreshedFederatedToken": "refreshedAccessToken",
|
|
}
|
|
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.HasSuffix(r.RequestURI, "/.well-known/openid-configuration") {
|
|
tenantURL := strings.TrimSuffix("https://"+r.Host+r.RequestURI, "/.well-known/openid-configuration")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
openidConfiguration := map[string]string{
|
|
"token_endpoint": tenantURL + "/oauth2/token",
|
|
"authorization_endpoint": tenantURL + "/oauth2/authorize",
|
|
"issuer": "https://fakeIssuer.com",
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(openidConfiguration); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
if err := r.ParseForm(); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
receivedFederatedToken := r.FormValue("client_assertion")
|
|
accessToken := map[string]string{
|
|
"access_token": tokens[receivedFederatedToken],
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(accessToken); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
// Expected format: http://<server>/<tenant-ID>/oauth2/token?api-version=1.0
|
|
assert.Contains(t, r.RequestURI, strings.ToLower(os.Getenv("AZURE_TENANT_ID")), "URI should contain the tenant ID exposed through env variable")
|
|
|
|
assert.Equal(t, os.Getenv("AZURE_CLIENT_ID"), r.FormValue("client_id"), "client_id should match the value exposed through env variable")
|
|
}))
|
|
defer ts.Close()
|
|
|
|
ambient := true
|
|
clientOpt := policy.ClientOptions{
|
|
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: ts.URL},
|
|
Transport: ts.Client(),
|
|
}
|
|
managedIdentity := &v1.AzureManagedIdentity{ClientID: ""}
|
|
|
|
spt, err := getAuthorization(clientOpt, "", "", "", ambient, managedIdentity)
|
|
assert.NoError(t, err)
|
|
|
|
for federatedToken, accessToken := range tokens {
|
|
populateFederatedToken(t, f.Name(), federatedToken)
|
|
token, err := spt.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: []string{"test"}})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, accessToken, token.Token, "Access token should have been set to a value returned by the webserver")
|
|
|
|
// Overwrite the expires field to force the token to be re-read.
|
|
newExpires := time.Now().Add(-1 * time.Second)
|
|
v := reflect.ValueOf(spt.(*azidentity.WorkloadIdentityCredential)).Elem()
|
|
expiresField := v.FieldByName("expires")
|
|
reflect.NewAt(expiresField.Type(), expiresField.Addr().UnsafePointer()).
|
|
Elem().Set(reflect.ValueOf(newExpires))
|
|
}
|
|
})
|
|
|
|
t.Run("clientID overrides through managedIdentity section", func(t *testing.T) {
|
|
managedIdentity := &v1.AzureManagedIdentity{ClientID: "anotherClientID"}
|
|
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.HasSuffix(r.RequestURI, "/.well-known/openid-configuration") {
|
|
tenantURL := strings.TrimSuffix("https://"+r.Host+r.RequestURI, "/.well-known/openid-configuration")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
openidConfiguration := map[string]string{
|
|
"token_endpoint": tenantURL + "/oauth2/token",
|
|
"authorization_endpoint": tenantURL + "/oauth2/authorize",
|
|
"issuer": "https://fakeIssuer.com",
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(openidConfiguration); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
if err := r.ParseForm(); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
accessToken := map[string]string{
|
|
"access_token": "abc",
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(accessToken); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
assert.Equal(t, managedIdentity.ClientID, r.FormValue("client_id"), "client_id should match the value passed through managedIdentity section")
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
ambient := true
|
|
clientOpt := policy.ClientOptions{
|
|
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: ts.URL},
|
|
Transport: ts.Client(),
|
|
}
|
|
|
|
spt, err := getAuthorization(clientOpt, "", "", "", ambient, managedIdentity)
|
|
assert.NoError(t, err)
|
|
|
|
token, err := spt.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: []string{"test"}})
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, token.Token, "Access token should have been set to a value returned by the webserver")
|
|
})
|
|
|
|
// This test tests the stabilizeError function, it makes sure that authentication errors
|
|
// are also made stable. We want our error messages to be the same when the cause
|
|
// is the same to avoid spurious challenge updates.
|
|
// Specifically, this test makes sure that the errors of type AuthenticationFailedError
|
|
// are made stable. These errors are returned by the recordClient and zoneClient when
|
|
// they fail to authenticate. We simulate this by calling the GetToken function and
|
|
// returning a 502 Bad Gateway error.
|
|
t.Run("errors should be made stable", func(t *testing.T) {
|
|
managedIdentity := &v1.AzureManagedIdentity{ClientID: "anotherClientID"}
|
|
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.HasSuffix(r.RequestURI, "/.well-known/openid-configuration") {
|
|
tenantURL := strings.TrimSuffix("https://"+r.Host+r.RequestURI, "/.well-known/openid-configuration")
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
openidConfiguration := map[string]string{
|
|
"token_endpoint": tenantURL + "/oauth2/token",
|
|
"authorization_endpoint": tenantURL + "/oauth2/authorize",
|
|
"issuer": "https://fakeIssuer.com",
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(openidConfiguration); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
randomMessage := "test error message: " + rand.String(10)
|
|
payload := fmt.Sprintf(`{"error":{"code":"TEST_ERROR_CODE","message":"%s"}}`, randomMessage)
|
|
if _, err := w.Write([]byte(payload)); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
}))
|
|
defer ts.Close()
|
|
|
|
ambient := true
|
|
clientOpt := policy.ClientOptions{
|
|
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: ts.URL},
|
|
Transport: ts.Client(),
|
|
}
|
|
|
|
spt, err := getAuthorization(clientOpt, "", "", "", ambient, managedIdentity)
|
|
assert.NoError(t, err)
|
|
|
|
_, err = spt.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: []string{"test"}})
|
|
err = stabilizeError(err)
|
|
assert.Error(t, err)
|
|
assert.ErrorContains(t, err, fmt.Sprintf(`authentication failed:
|
|
POST %s/adfs/oauth2/token
|
|
--------------------------------------------------------------------------------
|
|
RESPONSE 502 Bad Gateway
|
|
--------------------------------------------------------------------------------
|
|
see logs for more information`, ts.URL))
|
|
})
|
|
}
|
|
|
|
// TestStabilizeResponseError tests that the ResponseError errors returned by the AzureDNS API are
|
|
// changed to be stable. We want our error messages to be the same when the cause
|
|
// is the same to avoid spurious challenge updates.
|
|
func TestStabilizeResponseError(t *testing.T) {
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
randomMessage := "test error message: " + rand.String(10)
|
|
payload := fmt.Sprintf(`{"error":{"code":"TEST_ERROR_CODE","message":"%s"}}`, randomMessage)
|
|
if _, err := w.Write([]byte(payload)); err != nil {
|
|
assert.FailNow(t, err.Error())
|
|
}
|
|
}))
|
|
|
|
defer ts.Close()
|
|
|
|
clientOpt := policy.ClientOptions{
|
|
Cloud: cloud.Configuration{
|
|
ActiveDirectoryAuthorityHost: ts.URL,
|
|
Services: map[cloud.ServiceName]cloud.ServiceConfiguration{
|
|
cloud.ResourceManager: {
|
|
Audience: ts.URL,
|
|
Endpoint: ts.URL,
|
|
},
|
|
},
|
|
},
|
|
Transport: ts.Client(),
|
|
}
|
|
|
|
zc, err := dns.NewZonesClient("subscriptionID", nil, &arm.ClientOptions{ClientOptions: clientOpt})
|
|
require.NoError(t, err)
|
|
|
|
dnsProvider := DNSProvider{
|
|
dns01Nameservers: util.RecursiveNameservers,
|
|
resourceGroupName: "resourceGroupName",
|
|
zoneClient: zc,
|
|
}
|
|
|
|
err = dnsProvider.Present(context.TODO(), "test.com", "fqdn.test.com.", "test123")
|
|
require.Error(t, err)
|
|
require.ErrorContains(t, err, fmt.Sprintf(`Zone test.com. not found in AzureDNS for domain fqdn.test.com.. Err: request error:
|
|
GET %s/subscriptions/subscriptionID/resourceGroups/resourceGroupName/providers/Microsoft.Network/dnsZones/test.com
|
|
--------------------------------------------------------------------------------
|
|
RESPONSE 502 Bad Gateway
|
|
ERROR CODE: TEST_ERROR_CODE
|
|
--------------------------------------------------------------------------------
|
|
see logs for more information`, ts.URL))
|
|
}
|