diff --git a/pkg/util/pki/asn1_util.go b/pkg/util/pki/asn1_util.go index a2d02c17e..ebbe8fc02 100644 --- a/pkg/util/pki/asn1_util.go +++ b/pkg/util/pki/asn1_util.go @@ -48,6 +48,15 @@ func ParseObjectIdentifier(oidString string) (oid asn1.ObjectIdentifier, err err return oid, nil } +type UniversalValueType int + +const ( + UniversalValueTypeBytes UniversalValueType = iota + UniversalValueTypeIA5String + UniversalValueTypeUTF8String + UniversalValueTypePrintableString +) + type UniversalValue struct { Bytes []byte IA5String string @@ -55,50 +64,56 @@ type UniversalValue struct { PrintableString string } -func MarshalUniversalValue(uv UniversalValue) ([]byte, error) { - // Make sure we have only one field set - { - var count int - if uv.Bytes != nil { - count++ - } - if uv.IA5String != "" { - count++ - } - if uv.UTF8String != "" { - count++ - } - if uv.PrintableString != "" { - count++ - } - if count != 1 { - return nil, fmt.Errorf("exactly one field must be set") - } +func (uv UniversalValue) Type() UniversalValueType { + isBytes := uv.Bytes != nil + isIA5String := uv.IA5String != "" + isUTF8String := uv.UTF8String != "" + isPrintableString := uv.PrintableString != "" + + switch { + case isBytes && !isIA5String && !isUTF8String && !isPrintableString: + return UniversalValueTypeBytes + case !isBytes && isIA5String && !isUTF8String && !isPrintableString: + return UniversalValueTypeIA5String + case !isBytes && !isIA5String && isUTF8String && !isPrintableString: + return UniversalValueTypeUTF8String + case !isBytes && !isIA5String && !isUTF8String && isPrintableString: + return UniversalValueTypePrintableString } + return -1 // Either no field is set or two fields are set. +} + +func MarshalUniversalValue(uv UniversalValue) ([]byte, error) { + // Make sure we have only one field set + uvType := uv.Type() var bytes []byte - if uv.Bytes != nil { + switch uvType { + case -1: + return nil, errors.New("UniversalValue should have exactly one field set") + case UniversalValueTypeBytes: bytes = uv.Bytes - } else { + default: rawValue := asn1.RawValue{ Class: asn1.ClassUniversal, IsCompound: false, } - switch { - case uv.IA5String != "": + + switch uvType { + case UniversalValueTypeIA5String: if err := isIA5String(uv.IA5String); err != nil { return nil, errors.New("asn1: invalid IA5 string") } rawValue.Tag = asn1.TagIA5String rawValue.Bytes = []byte(uv.IA5String) - case uv.UTF8String != "": + case UniversalValueTypeUTF8String: if !utf8.ValidString(uv.UTF8String) { return nil, errors.New("asn1: invalid UTF-8 string") } rawValue.Tag = asn1.TagUTF8String rawValue.Bytes = []byte(uv.UTF8String) - case uv.PrintableString != "": + case UniversalValueTypePrintableString: if !isPrintable(uv.PrintableString) { return nil, errors.New("asn1: invalid PrintableString string") } diff --git a/pkg/util/pki/match.go b/pkg/util/pki/match.go index 7f5c83f64..b735cab9f 100644 --- a/pkg/util/pki/match.go +++ b/pkg/util/pki/match.go @@ -21,6 +21,8 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "crypto/x509/pkix" + "encoding/asn1" "net" "fmt" @@ -148,6 +150,16 @@ func RequestMatchesSpec(req *cmapi.CertificateRequest, spec cmapi.CertificateSpe violations = append(violations, "spec.dnsNames") } + if spec.OtherNames != nil { + matched, err := matchOtherNames(x509req.Extensions, spec.OtherNames) + if err != nil { + return nil, err + } + if !matched { + violations = append(violations, "spec.otherNames") + } + } + if spec.LiteralSubject == "" { // Comparing Subject fields if x509req.Subject.CommonName != spec.CommonName { @@ -216,6 +228,51 @@ func RequestMatchesSpec(req *cmapi.CertificateRequest, spec cmapi.CertificateSpe return violations, nil } +func matchOtherNames(extension []pkix.Extension, specOtherNames []cmapi.OtherName) (bool, error) { + x509SANExtension, err := extractSANExtension(extension) + if err != nil { + return false, nil + } + + x509GeneralNames, err := UnmarshalSANs(x509SANExtension.Value) + if err != nil { + return false, err + } + + x509OtherNames := make([]cmapi.OtherName, 0, len(x509GeneralNames.OtherNames)) + for _, otherName := range x509GeneralNames.OtherNames { + + var otherNameInnerValue asn1.RawValue + // We have to perform one more level of unwrapping because value is still context specific class + // tagged 0 + _, err := asn1.Unmarshal(otherName.Value.Bytes, &otherNameInnerValue) + if err != nil { + return false, err + } + + uv, err := UnmarshalUniversalValue(otherNameInnerValue) + if err != nil { + return false, err + } + + if uv.Type() != UniversalValueTypeUTF8String { + // This means the CertificateRequest's otherName was not an utf8 value + return false, fmt.Errorf("otherName is not an utf8 value, got: %v", uv.Type()) + } + + x509OtherNames = append(x509OtherNames, cmapi.OtherName{ + OID: otherName.TypeID.String(), + UTF8Value: uv.UTF8String, + }) + } + + if !util.EqualOtherNamesUnsorted(x509OtherNames, specOtherNames) { + return false, nil + } + + return true, nil +} + // SecretDataAltNamesMatchSpec will compare a Secret resource containing certificate // data to a CertificateSpec and return a list of 'violations' for any fields that // do not match their counterparts. @@ -267,3 +324,15 @@ func SecretDataAltNamesMatchSpec(secret *corev1.Secret, spec cmapi.CertificateSp return violations, nil } + +func extractSANExtension(extensions []pkix.Extension) (pkix.Extension, error) { + oidExtensionSubjectAltName := []int{2, 5, 29, 17} + + for _, extension := range extensions { + if extension.Id.Equal(oidExtensionSubjectAltName) { + return extension, nil + } + } + + return pkix.Extension{}, fmt.Errorf("SAN extension not present!") +} diff --git a/pkg/util/pki/match_test.go b/pkg/util/pki/match_test.go index e9d961787..50516aff4 100644 --- a/pkg/util/pki/match_test.go +++ b/pkg/util/pki/match_test.go @@ -17,11 +17,15 @@ limitations under the License. package pki import ( + "bytes" "crypto" + "crypto/x509" + "encoding/pem" "reflect" "testing" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" cmapi "github.com/cert-manager/cert-manager/pkg/apis/certmanager/v1" ) @@ -119,6 +123,108 @@ func TestPrivateKeyMatchesSpec(t *testing.T) { } } +func TestCertificateRequestOtherNamesMatchSpec(t *testing.T) { + tests := map[string]struct { + crSpec *cmapi.CertificateRequest + certSpec cmapi.CertificateSpec + err string + violations []string + }{ + "should not report any violation if Certificate otherName(s) match the CertificateRequest's": { + crSpec: MustBuildCertificateRequest(&cmapi.Certificate{Spec: cmapi.CertificateSpec{ + CommonName: "cn", + OtherNames: []cmapi.OtherName{ + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "upn@testdomain.local", + }, + }, + }}, t), + certSpec: cmapi.CertificateSpec{ + CommonName: "cn", + OtherNames: []cmapi.OtherName{ + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "upn@testdomain.local", + }, + }, + }, + err: "", + }, + "should report violation if Certificate otherName(s) mismatch the CertificateRequest's": { + crSpec: MustBuildCertificateRequest(&cmapi.Certificate{Spec: cmapi.CertificateSpec{ + CommonName: "cn", + OtherNames: []cmapi.OtherName{ + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "upn@testdomain.local", + }, + }, + }}, t), + certSpec: cmapi.CertificateSpec{ + CommonName: "cn", + OtherNames: []cmapi.OtherName{ + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "upn2@testdomain.local", + }, + }, + }, + err: "", + violations: []string{ + "spec.otherNames", + }, + }, + "should not report violation if Certificate otherName(s) match the CertificateRequest's (with different order)": { + crSpec: MustBuildCertificateRequest(&cmapi.Certificate{Spec: cmapi.CertificateSpec{ + CommonName: "cn", + OtherNames: []cmapi.OtherName{ + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "anotherupn@testdomain.local", + }, + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "upn@testdomain.local", + }, + }, + }}, t), + certSpec: cmapi.CertificateSpec{ + CommonName: "cn", + OtherNames: []cmapi.OtherName{ + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "upn@testdomain.local", + }, + { + OID: "1.3.6.1.4.1.311.20.2.3", + UTF8Value: "anotherupn@testdomain.local", + }, + }, + }, + err: "", + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + violations, err := RequestMatchesSpec(test.crSpec, test.certSpec) + if err != nil { + if test.err == "" { + t.Errorf("Unexpected error: %s", err.Error()) + } else { + if test.err != err.Error() { + t.Errorf("Expected error: %s but got: %s instead", err.Error(), test.err) + } + } + } + + if !reflect.DeepEqual(violations, test.violations) { + t.Errorf("violations did not match, got=%s, exp=%s", violations, test.violations) + } + }) + } +} + func TestSecretDataAltNamesMatchSpec(t *testing.T) { tests := map[string]struct { data []byte @@ -289,3 +395,38 @@ func selfSignCertificate(t *testing.T, spec cmapi.CertificateSpec) []byte { return pemData } + +func MustBuildCertificateRequest(crt *cmapi.Certificate, t *testing.T) *cmapi.CertificateRequest { + pk, err := GenerateRSAPrivateKey(2048) + if err != nil { + t.Fatal(err) + } + + csrTemplate, err := GenerateCSR(crt, WithOtherNames(true)) + if err != nil { + t.Fatal(err) + } + + var buffer bytes.Buffer + csr, err := x509.CreateCertificateRequest(&buffer, csrTemplate, pk) + if err != nil { + t.Fatal(err) + } + pemData := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csr}) + cr := &cmapi.CertificateRequest{ + ObjectMeta: metav1.ObjectMeta{ + Name: t.Name(), + Annotations: crt.Annotations, + Labels: crt.Labels, + }, + Spec: cmapi.CertificateRequestSpec{ + Request: pemData, + Duration: crt.Spec.Duration, + IssuerRef: crt.Spec.IssuerRef, + IsCA: crt.Spec.IsCA, + Usages: crt.Spec.Usages, + }, + } + + return cr +} diff --git a/pkg/util/util.go b/pkg/util/util.go index 868e862fc..30a4030c5 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -85,6 +85,17 @@ func EqualURLsUnsorted(s1, s2 []*url.URL) bool { }) } +// Test for equal cmapi.OtherName slices even if unsorted. Panics if any element is nil +func EqualOtherNamesUnsorted(s1, s2 []cmapi.OtherName) bool { + return genericEqualUnsorted(s1, s2, func(a cmapi.OtherName, b cmapi.OtherName) int { + if a.OID == b.OID { + return strings.Compare(a.UTF8Value, b.UTF8Value) + } + return strings.Compare(a.OID, b.OID) + }) + +} + // EqualIPsUnsorted checks if the given slices of IP addresses contain the same elements, even if in a different order func EqualIPsUnsorted(s1, s2 []net.IP) bool { // Two IPv4 addresses can compare unequal with bytes.Equal which is why net.IP.Equal exists.