cert-manager/pkg/issuer/acme/dns/azuredns/azuredns_test.go
Bartosz Slawianowski 30d4fce8a8 Add test case
Signed-off-by: Bartosz Slawianowski <bartosz.slawianowski@natzka.com>
2024-07-16 18:28:06 +02:00

426 lines
16 KiB
Go

// +skip_license_check
/*
This file contains portions of code directly taken from the 'xenolf/lego' project.
A copy of the license for this code can be found in the file named LICENSE in
this directory.
*/
package azuredns
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/rand"
v1 "github.com/cert-manager/cert-manager/pkg/apis/acme/v1"
"github.com/cert-manager/cert-manager/pkg/issuer/acme/dns/util"
)
// Settings for the live Azure DNS tests. init loads each string from the
// process environment.
var (
	// azureLiveTest gates the TestLiveAzureDns* tests: they skip themselves
	// while it is false.
	azureLiveTest bool

	// Read from AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_SUBSCRIPTION_ID and
	// AZURE_TENANT_ID respectively.
	azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID string

	// Read from AZURE_RESOURCE_GROUP, AZURE_ZONE_NAME and AZURE_DOMAIN
	// respectively.
	azureResourceGroupName, azureHostedZoneName, azureDomain string
)
// init copies the live-test settings out of the environment and switches the
// live tests on only when a client ID, a client secret and a domain were all
// supplied.
func init() {
	settings := []struct {
		target *string
		envVar string
	}{
		{&azureClientID, "AZURE_CLIENT_ID"},
		{&azureClientSecret, "AZURE_CLIENT_SECRET"},
		{&azuresubscriptionID, "AZURE_SUBSCRIPTION_ID"},
		{&azureTenantID, "AZURE_TENANT_ID"},
		{&azureResourceGroupName, "AZURE_RESOURCE_GROUP"},
		{&azureHostedZoneName, "AZURE_ZONE_NAME"},
		{&azureDomain, "AZURE_DOMAIN"},
	}
	for _, s := range settings {
		*s.target = os.Getenv(s.envVar)
	}

	// Leave azureLiveTest untouched unless every required value is present.
	if azureClientID == "" || azureClientSecret == "" || azureDomain == "" {
		return
	}
	azureLiveTest = true
}
// TestLiveAzureDnsPresent calls Present once for the ACME challenge FQDN of
// azureDomain using the credentials loaded by init. It is skipped unless
// azureLiveTest is set.
func TestLiveAzureDnsPresent(t *testing.T) {
	if !azureLiveTest {
		t.Skip("skipping live test")
	}

	provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
	// Abort on failure: the provider must not be used when construction failed.
	require.NoError(t, err)

	err = provider.Present(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "123d==")
	assert.NoError(t, err)
}
// TestLiveAzureDnsPresentMultiple calls Present twice for the same challenge
// FQDN with two different values. It is skipped unless azureLiveTest is set.
func TestLiveAzureDnsPresentMultiple(t *testing.T) {
	if !azureLiveTest {
		t.Skip("skipping live test")
	}

	provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
	// Abort on failure: the provider must not be used when construction failed.
	require.NoError(t, err)

	fqdn := "_acme-challenge." + azureDomain + "."

	err = provider.Present(context.TODO(), azureDomain, fqdn, "123d==")
	assert.NoError(t, err)

	err = provider.Present(context.TODO(), azureDomain, fqdn, "1123d==")
	assert.NoError(t, err)
}
// TestLiveAzureDnsCleanUp calls CleanUp for the challenge value used by
// TestLiveAzureDnsPresent. It is skipped unless azureLiveTest is set.
func TestLiveAzureDnsCleanUp(t *testing.T) {
	if !azureLiveTest {
		t.Skip("skipping live test")
	}

	// NOTE(review): presumably gives the earlier Present calls time to take
	// effect on the Azure side before cleaning up — confirm.
	time.Sleep(time.Second * 5)

	provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
	// Abort on failure: the provider must not be used when construction failed.
	require.NoError(t, err)

	err = provider.CleanUp(context.TODO(), azureDomain, "_acme-challenge."+azureDomain+".", "123d==")
	assert.NoError(t, err)
}
// TestLiveAzureDnsCleanUpMultiple calls CleanUp for both challenge values used
// by TestLiveAzureDnsPresentMultiple. It is skipped unless azureLiveTest is set.
func TestLiveAzureDnsCleanUpMultiple(t *testing.T) {
	if !azureLiveTest {
		t.Skip("skipping live test")
	}

	// NOTE(review): presumably gives the earlier Present calls time to take
	// effect on the Azure side before cleaning up — confirm.
	time.Sleep(time.Second * 10)

	provider, err := NewDNSProviderCredentials("", azureClientID, azureClientSecret, azuresubscriptionID, azureTenantID, azureResourceGroupName, azureHostedZoneName, util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
	// Abort on failure: the provider must not be used when construction failed.
	require.NoError(t, err)

	fqdn := "_acme-challenge." + azureDomain + "."

	err = provider.CleanUp(context.TODO(), azureDomain, fqdn, "123d==")
	assert.NoError(t, err)

	err = provider.CleanUp(context.TODO(), azureDomain, fqdn, "1123d==")
	assert.NoError(t, err)
}
// TestInvalidAzureDns verifies which environment names and tenant IDs
// NewDNSProviderCredentials accepts: the known cloud names (and the empty
// default) must succeed, while an unknown environment or a malformed tenant ID
// must be rejected.
func TestInvalidAzureDns(t *testing.T) {
	// construct builds a provider where only the environment and the tenant ID
	// vary, and reports the constructor's error.
	construct := func(environment, tenantID string) error {
		_, err := NewDNSProviderCredentials(environment, "cid", "secret", "", tenantID, "", "", util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
		return err
	}

	for _, environment := range []string{"", "AzurePublicCloud", "AzureChinaCloud", "AzureUSGovernmentCloud"} {
		assert.NoError(t, construct(environment, "tenid"))
	}

	// Invalid environment
	assert.Error(t, construct("invalid env", "tenid"))

	// Invalid tenantID
	assert.Error(t, construct("", "invalid env value"))
}
// TestAuthenticationError checks that Present and CleanUp both return an error
// when the provider was built with credentials that cannot authenticate.
func TestAuthenticationError(t *testing.T) {
	provider, err := NewDNSProviderCredentials("", "invalid-client-id", "invalid-client-secret", "subid", "tenid", "rg", "example.com", util.RecursiveNameservers, false, &v1.AzureManagedIdentity{})
	// Abort on failure: the provider must not be used when construction failed.
	require.NoError(t, err)

	err = provider.Present(context.TODO(), "example.com", "_acme-challenge.example.com.", "123d==")
	assert.Error(t, err)

	err = provider.CleanUp(context.TODO(), "example.com", "_acme-challenge.example.com.", "123d==")
	assert.Error(t, err)
}
// populateFederatedToken replaces the contents of filename with content,
// creating the file if necessary, and fails the calling test immediately if the
// file cannot be created, written or closed.
func populateFederatedToken(t *testing.T, filename string, content string) {
	t.Helper()

	f, err := os.Create(filename)
	if err != nil {
		assert.FailNow(t, err.Error())
	}

	// Close before inspecting the write error so the file handle is released
	// even when the write failed.
	_, writeErr := io.WriteString(f, content)
	closeErr := f.Close()

	if writeErr != nil {
		assert.FailNow(t, writeErr.Error())
	}
	if closeErr != nil {
		assert.FailNow(t, closeErr.Error())
	}
}
func TestGetAuthorizationFederatedSPT(t *testing.T) {
// Create a file that will be used to store a federated token
f, err := os.CreateTemp("", "")
if err != nil {
assert.FailNow(t, err.Error())
}
defer os.Remove(f.Name())
// Close the file to simplify logic within populateFederatedToken helper
if err := f.Close(); err != nil {
assert.FailNow(t, err.Error())
}
// The initial federated token is never used, so we don't care about the value yet
// Though, it's a requirement from adal to have a non-empty value set
populateFederatedToken(t, f.Name(), "random-jwt")
// Prepare environment variables adal will rely on. Skip changes for some envs if they are already defined (=live environment)
// Envs themselves are described here: https://azure.github.io/azure-workload-identity/docs/installation/mutating-admission-webhook.html
if os.Getenv("AZURE_TENANT_ID") == "" {
// TODO(wallrj): This is a hack. It is a quick way to `DisableInstanceDiscovery` during tests,
// to avoid the client attempting to connect to https://login.microsoftonline.com/common/discovery/instance.
// It works because there is a special case in azure-sdk-for-go which
// disables the instance discovery when the tenant ID is `adfs`. See:
// https://github.com/Azure/azure-sdk-for-go/blob/7288bda422654bde520a09034dd755b8f2dd4168/sdk/azidentity/public_client.go#L237-L239
// https://learn.microsoft.com/en-us/windows-server/identity/ad-fs/ad-fs-overview
//
// Find a better way to test this code.
t.Setenv("AZURE_TENANT_ID", "adfs")
}
if os.Getenv("AZURE_CLIENT_ID") == "" {
t.Setenv("AZURE_CLIENT_ID", "fakeClientID")
}
t.Setenv("AZURE_FEDERATED_TOKEN_FILE", f.Name())
t.Run("token refresh", func(t *testing.T) {
// Basically, we want one token to be exchanged for the other (key and value respectively)
tokens := map[string]string{
"initialFederatedToken": "initialAccessToken",
"refreshedFederatedToken": "refreshedAccessToken",
}
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, "/.well-known/openid-configuration") {
tenantURL := strings.TrimSuffix("https://"+r.Host+r.RequestURI, "/.well-known/openid-configuration")
w.Header().Set("Content-Type", "application/json")
openidConfiguration := map[string]string{
"token_endpoint": tenantURL + "/oauth2/token",
"authorization_endpoint": tenantURL + "/oauth2/authorize",
"issuer": "https://fakeIssuer.com",
}
if err := json.NewEncoder(w).Encode(openidConfiguration); err != nil {
assert.FailNow(t, err.Error())
}
return
}
if err := r.ParseForm(); err != nil {
assert.FailNow(t, err.Error())
}
w.Header().Set("Content-Type", "application/json")
receivedFederatedToken := r.FormValue("client_assertion")
accessToken := map[string]string{
"access_token": tokens[receivedFederatedToken],
}
if err := json.NewEncoder(w).Encode(accessToken); err != nil {
assert.FailNow(t, err.Error())
}
// Expected format: http://<server>/<tenant-ID>/oauth2/token?api-version=1.0
assert.Contains(t, r.RequestURI, strings.ToLower(os.Getenv("AZURE_TENANT_ID")), "URI should contain the tenant ID exposed through env variable")
assert.Equal(t, os.Getenv("AZURE_CLIENT_ID"), r.FormValue("client_id"), "client_id should match the value exposed through env variable")
}))
defer ts.Close()
ambient := true
clientOpt := policy.ClientOptions{
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: ts.URL},
Transport: ts.Client(),
}
managedIdentity := &v1.AzureManagedIdentity{ClientID: ""}
spt, err := getAuthorization(clientOpt, "", "", "", ambient, managedIdentity)
assert.NoError(t, err)
for federatedToken, accessToken := range tokens {
populateFederatedToken(t, f.Name(), federatedToken)
token, err := spt.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: []string{"test"}})
assert.NoError(t, err)
assert.Equal(t, accessToken, token.Token, "Access token should have been set to a value returned by the webserver")
// Overwrite the expires field to force the token to be re-read.
newExpires := time.Now().Add(-1 * time.Second)
v := reflect.ValueOf(spt.(*azidentity.WorkloadIdentityCredential)).Elem()
expiresField := v.FieldByName("expires")
reflect.NewAt(expiresField.Type(), expiresField.Addr().UnsafePointer()).
Elem().Set(reflect.ValueOf(newExpires))
}
})
t.Run("clientID overrides through managedIdentity section", func(t *testing.T) {
managedIdentity := &v1.AzureManagedIdentity{ClientID: "anotherClientID"}
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, "/.well-known/openid-configuration") {
tenantURL := strings.TrimSuffix("https://"+r.Host+r.RequestURI, "/.well-known/openid-configuration")
w.Header().Set("Content-Type", "application/json")
openidConfiguration := map[string]string{
"token_endpoint": tenantURL + "/oauth2/token",
"authorization_endpoint": tenantURL + "/oauth2/authorize",
"issuer": "https://fakeIssuer.com",
}
if err := json.NewEncoder(w).Encode(openidConfiguration); err != nil {
assert.FailNow(t, err.Error())
}
return
}
if err := r.ParseForm(); err != nil {
assert.FailNow(t, err.Error())
}
w.Header().Set("Content-Type", "application/json")
accessToken := map[string]string{
"access_token": "abc",
}
if err := json.NewEncoder(w).Encode(accessToken); err != nil {
assert.FailNow(t, err.Error())
}
assert.Equal(t, managedIdentity.ClientID, r.FormValue("client_id"), "client_id should match the value passed through managedIdentity section")
w.WriteHeader(http.StatusOK)
}))
defer ts.Close()
ambient := true
clientOpt := policy.ClientOptions{
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: ts.URL},
Transport: ts.Client(),
}
spt, err := getAuthorization(clientOpt, "", "", "", ambient, managedIdentity)
assert.NoError(t, err)
token, err := spt.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: []string{"test"}})
assert.NoError(t, err)
assert.NotEmpty(t, token.Token, "Access token should have been set to a value returned by the webserver")
})
// This test tests the stabilizeError function, it makes sure that authentication errors
// are also made stable. We want our error messages to be the same when the cause
// is the same to avoid spurious challenge updates.
// Specifically, this test makes sure that the errors of type AuthenticationFailedError
// are made stable. These errors are returned by the recordClient and zoneClient when
// they fail to authenticate. We simulate this by calling the GetToken function and
// returning a 502 Bad Gateway error.
t.Run("errors should be made stable", func(t *testing.T) {
managedIdentity := &v1.AzureManagedIdentity{ClientID: "anotherClientID"}
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, "/.well-known/openid-configuration") {
tenantURL := strings.TrimSuffix("https://"+r.Host+r.RequestURI, "/.well-known/openid-configuration")
w.Header().Set("Content-Type", "application/json")
openidConfiguration := map[string]string{
"token_endpoint": tenantURL + "/oauth2/token",
"authorization_endpoint": tenantURL + "/oauth2/authorize",
"issuer": "https://fakeIssuer.com",
}
if err := json.NewEncoder(w).Encode(openidConfiguration); err != nil {
assert.FailNow(t, err.Error())
}
return
}
w.WriteHeader(http.StatusBadGateway)
randomMessage := "test error message: " + rand.String(10)
payload := fmt.Sprintf(`{"error":{"code":"TEST_ERROR_CODE","message":"%s"}}`, randomMessage)
if _, err := w.Write([]byte(payload)); err != nil {
assert.FailNow(t, err.Error())
}
}))
defer ts.Close()
ambient := true
clientOpt := policy.ClientOptions{
Cloud: cloud.Configuration{ActiveDirectoryAuthorityHost: ts.URL},
Transport: ts.Client(),
}
spt, err := getAuthorization(clientOpt, "", "", "", ambient, managedIdentity)
assert.NoError(t, err)
_, err = spt.GetToken(context.TODO(), policy.TokenRequestOptions{Scopes: []string{"test"}})
err = stabilizeError(err)
assert.Error(t, err)
assert.ErrorContains(t, err, fmt.Sprintf(`authentication failed:
POST %s/adfs/oauth2/token
--------------------------------------------------------------------------------
RESPONSE 502 Bad Gateway
--------------------------------------------------------------------------------
see logs for more information`, ts.URL))
})
}
// TestStabilizeResponseError tests that the ResponseError errors returned by the AzureDNS API are
// changed to be stable. We want our error messages to be the same when the cause
// is the same to avoid spurious challenge updates.
func TestStabilizeResponseError(t *testing.T) {
	// Every request gets a 502 whose message differs each time; the expected
	// error text below must therefore not contain that message.
	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadGateway)
		randomMessage := "test error message: " + rand.String(10)
		payload := fmt.Sprintf(`{"error":{"code":"TEST_ERROR_CODE","message":"%s"}}`, randomMessage)
		// The handler runs on a server goroutine, so the failure is recorded
		// without FailNow, which may only be called from the test goroutine.
		_, err := w.Write([]byte(payload))
		assert.NoError(t, err)
	}))
	defer ts.Close()

	// Point both the authority host and the resource manager at the test server.
	clientOpt := policy.ClientOptions{
		Cloud: cloud.Configuration{
			ActiveDirectoryAuthorityHost: ts.URL,
			Services: map[cloud.ServiceName]cloud.ServiceConfiguration{
				cloud.ResourceManager: {
					Audience: ts.URL,
					Endpoint: ts.URL,
				},
			},
		},
		Transport: ts.Client(),
	}

	zc, err := dns.NewZonesClient("subscriptionID", nil, &arm.ClientOptions{ClientOptions: clientOpt})
	require.NoError(t, err)

	dnsProvider := DNSProvider{
		dns01Nameservers:  util.RecursiveNameservers,
		resourceGroupName: "resourceGroupName",
		zoneClient:        zc,
	}

	err = dnsProvider.Present(context.TODO(), "test.com", "fqdn.test.com.", "test123")
	require.Error(t, err)
	require.ErrorContains(t, err, fmt.Sprintf(`Zone test.com. not found in AzureDNS for domain fqdn.test.com.. Err: request error:
GET %s/subscriptions/subscriptionID/resourceGroups/resourceGroupName/providers/Microsoft.Network/dnsZones/test.com
--------------------------------------------------------------------------------
RESPONSE 502 Bad Gateway
ERROR CODE: TEST_ERROR_CODE
--------------------------------------------------------------------------------
see logs for more information`, ts.URL))
}