diff --git a/pkg/util/pki/match.go b/pkg/util/pki/match.go index b735cab9f..d01d76d1f 100644 --- a/pkg/util/pki/match.go +++ b/pkg/util/pki/match.go @@ -17,6 +17,7 @@ limitations under the License. package pki import ( + "bytes" "crypto" "crypto/ecdsa" "crypto/ed25519" @@ -191,20 +192,20 @@ func RequestMatchesSpec(req *cmapi.CertificateRequest, spec cmapi.CertificateSpe } } else { - // we have a LiteralSubject - // parse the subject of the csr in the same way as we parse LiteralSubject and see whether the RDN Sequences match - - rdnSequenceFromCertificateRequest, err := UnmarshalRawDerBytesToRDNSequence(x509req.RawSubject) - if err != nil { - return nil, err - } + // we have a LiteralSubject, generate the RDNSequence and encode it to compare + // with the request's subject rdnSequenceFromCertificate, err := UnmarshalSubjectStringToRDNSequence(spec.LiteralSubject) if err != nil { return nil, err } - if !reflect.DeepEqual(rdnSequenceFromCertificate, rdnSequenceFromCertificateRequest) { + asn1Sequence, err := asn1.Marshal(rdnSequenceFromCertificate) + if err != nil { + return nil, err + } + + if !bytes.Equal(x509req.RawSubject, asn1Sequence) { violations = append(violations, "spec.literalSubject") } } diff --git a/pkg/util/pki/match_test.go b/pkg/util/pki/match_test.go index 50516aff4..43c3e2fe0 100644 --- a/pkg/util/pki/match_test.go +++ b/pkg/util/pki/match_test.go @@ -20,6 +20,7 @@ import ( "bytes" "crypto" "crypto/x509" + "encoding/asn1" "encoding/pem" "reflect" "testing" @@ -225,6 +226,92 @@ func TestCertificateRequestOtherNamesMatchSpec(t *testing.T) { } } +func TestRequestMatchesSpecSubject(t *testing.T) { + createCSRBlob := func(literalSubject string) []byte { + pk, err := GenerateRSAPrivateKey(2048) + if err != nil { + t.Fatal(err) + } + + seq, err := UnmarshalSubjectStringToRDNSequence(literalSubject) + if err != nil { + t.Fatal(err) + } + + asn1Seq, err := asn1.Marshal(seq) + if err != nil { + t.Fatal(err) + } + + csr := &x509.CertificateRequest{ + RawSubject: asn1Seq, + } + + csrBytes, err := x509.CreateCertificateRequest(bytes.NewBuffer(nil), csr, pk) + if err != nil { + t.Fatal(err) + } + + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes}) + } + + tests := []struct { + name string + subject *cmapi.X509Subject + literalSubject string + x509CSR []byte + err string + violations []string + }{ + { + name: "Matching LiteralSubjects", + literalSubject: "CN=example.com,OU=example,O=example,L=example,ST=example,C=US", + x509CSR: createCSRBlob("CN=example.com,OU=example,O=example,L=example,ST=example,C=US"), + }, + { + name: "Matching LiteralSubjects", + literalSubject: "ST=example,C=US", + x509CSR: createCSRBlob("ST=example"), + violations: []string{"spec.literalSubject"}, + }, + { + name: "Matching LiteralSubjects", + literalSubject: "ST=example,C=US,O=#04024869", + x509CSR: createCSRBlob("ST=example,C=US,O=#04024869"), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + violations, err := RequestMatchesSpec( + &cmapi.CertificateRequest{ + Spec: cmapi.CertificateRequestSpec{ + Request: test.x509CSR, + }, + }, + cmapi.CertificateSpec{ + Subject: test.subject, + LiteralSubject: test.literalSubject, + }, + ) + if err != nil { + if test.err == "" { + t.Errorf("Unexpected error: %s", err.Error()) + } else { + if test.err != err.Error() { + t.Errorf("Expected error: %s but got: %s instead", err.Error(), test.err) + } + } + } + + if !reflect.DeepEqual(violations, test.violations) { + t.Errorf("violations did not match, got=%s, exp=%s", violations, test.violations) + } + }) + } +} + func TestSecretDataAltNamesMatchSpec(t *testing.T) { tests := map[string]struct { data []byte