diff --git a/pkg/util/pki/parse.go b/pkg/util/pki/parse.go index 219c36082..93572515a 100644 --- a/pkg/util/pki/parse.go +++ b/pkg/util/pki/parse.go @@ -140,3 +140,199 @@ func DecodeX509CertificateRequestBytes(csrBytes []byte) (*x509.CertificateReques return csr, nil } + +// PEMBundle includes the PEM encoded X.509 certificate chain and CA. CAPEM +// contains either 1 CA certificate, or is empty if only a single certificate +// exists in the chain. +type PEMBundle struct { + CAPEM []byte + ChainPEM []byte +} + +type chainNode struct { + cert *x509.Certificate + issuer *chainNode +} + +// ParseCertificateChainPEM decodes a PEM encoded certificate chain before +// calling ParseCertificateChain +func ParseCertificateChainPEM(pembundle []byte) (PEMBundle, error) { + certs, err := DecodeX509CertificateChainBytes(pembundle) + if err != nil { + return PEMBundle{}, err + } + return ParseCertificateChain(certs) +} + +// ParseCertificateChain returns the PEM-encoded chain of certificates as well +// as the PEM-encoded CA certificate. The certificate chain contains the +// leaf certificate first. +// +// The CA may not be a true root, but the highest intermediate certificate. +// The returned CA may be empty if a single certificate was passed. +// +// This function removes duplicate certificate entries as well as comments and +// unnecessary white space. +// +// An error is returned if the passed bundle is not a valid flat tree chain, +// the bundle is malformed, or the chain is broken. +func ParseCertificateChain(certs []*x509.Certificate) (PEMBundle, error) { + // De-duplicate certificates. This moves "complicated" logic away from + // consumers and into a shared function, who would otherwise have to do this + // anyway. + for i := 0; i < len(certs)-1; i++ { + for j := 1; j < len(certs); j++ { + if i == j { + continue + } + if certs[i].Equal(certs[j]) { + certs = append(certs[:i], certs[i+1:]...) + } + } + } + + // A certificate chain can be well described as a linked list. Here we build + // multiple lists that contain a single node, each being a single certificate + // that was passed. + var chains []*chainNode + for i := range certs { + chains = append(chains, &chainNode{cert: certs[i]}) + } + + // The task is to build a single list which represents a single certificate + // chain. The strategy is to iteratively attempt to join items in the list to + // build this single chain. Once we have a single list, we have built the + // chain. If the number of lists do not decrease after a pass, then the list + // can never be reduced to a single chain and we error. + for { + // If a single list is left, then we have built the entire chain. Stop + // iterating. + if len(chains) == 1 { + break + } + + // lastChainsLength is used to ensure that at every pass, the number of + // tested chains gets smaller. + lastChainsLength := len(chains) + for i := 0; i < len(chains)-1; i++ { + for j := 1; j < len(chains); j++ { + if i == j { + continue + } + + // attempt to add both chain together + chain, ok := chains[i].tryMergeChain(chains[j]) + if ok { + // If adding the chains together was successful, remove inner chain from + // list + chains = append(chains[:j], chains[j+1:]...) + } + + chains[i] = chain + } + } + + // If no chains were merged in this pass, the chain can never be built as a + // single list. Error. + if lastChainsLength == len(chains) { + return PEMBundle{}, errors.NewInvalidData("certificate chain is malformed or broken") + } + } + + // There is only a single chain left at index 0. Return chain as PEM. + return chains[0].toBundleAndCA() +} + +// toBundleAndCA will return the PEM bundle of this chain. +func (c *chainNode) toBundleAndCA() (PEMBundle, error) { + var ( + certs []*x509.Certificate + ca *x509.Certificate + ) + + for { + // If the issuer is nil, we have hit the root of the chain. Assign the CA + // to this certificate and stop traversing. + if c.issuer == nil { + ca = c.cert + break + } + + // Add this node's certificate to the list at the end. Ready to check + // next node up. + certs = append(certs, c.cert) + c = c.issuer + } + + caPEM, err := EncodeX509(ca) + if err != nil { + return PEMBundle{}, err + } + + // If no certificates parsed, then CA is the only certificate and should be + // the chain + if len(certs) == 0 { + return PEMBundle{ChainPEM: caPEM}, nil + } + + // Encode full certificate chain + chainPEM, err := EncodeX509Chain(certs) + if err != nil { + return PEMBundle{}, err + } + + // Return chain and ca + return PEMBundle{CAPEM: caPEM, ChainPEM: chainPEM}, nil +} + +// tryMergeChain glues two chains A and B together by adding one on top of +// the other. The function tries both gluing A on top of B and B on top of +// A, which is why the argument order for the two input chains does not +// matter. +// +// Gluability: We say that the chains A and B are glueable when either the +// leaf certificate of A can be verified using the root certificate of B, +// or that the leaf certificate of B can be verified using the root certificate +// of A. +// +// A leaf certificate C (as in "child") is verified by a certificate P +// (as in "parent"), when they satisfy C.CheckSignatureFrom(P). In the +// following diagram, C.CheckSignatureFrom(P) is satisfied, i.e., the +// signature ("sig") on the certificate C can be verified using the parent P: +// +// head tail +// +------+-------+ +------+-------+ +------+-------+ +// | | | | | | | | | +// | | sig ------->| C | sig ------->| P | | +// | | | | | | | | | +// +------+-------+ +------+-------+ +------+-------+ +// leaf certificate root certificate +// +// The function returns false if the chains A and B are not gluable. +func (c *chainNode) tryMergeChain(chain *chainNode) (*chainNode, bool) { + // The given chain's root has been signed by this node. Add this node on top + // of the given chain. + if chain.root().cert.CheckSignatureFrom(c.cert) == nil { + chain.root().issuer = c + return chain, true + } + + // The given chain is the issuer of the root of this node. Add the given + // chain on top of the root of this node. + if c.root().cert.CheckSignatureFrom(chain.cert) == nil { + c.root().issuer = chain + return c, true + } + + // Chains cannot be added together. + return c, false +} + +// Return the root most node of this chain. +func (c *chainNode) root() *chainNode { + for c.issuer != nil { + c = c.issuer + } + + return c +} diff --git a/pkg/util/pki/parse_test.go b/pkg/util/pki/parse_test.go index 110661f24..e106dcc19 100644 --- a/pkg/util/pki/parse_test.go +++ b/pkg/util/pki/parse_test.go @@ -17,11 +17,17 @@ limitations under the License. package pki import ( + "crypto" "crypto/ecdsa" + "crypto/rand" "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" "encoding/pem" + "reflect" "strings" "testing" + "time" v1 "github.com/jetstack/cert-manager/pkg/apis/certmanager/v1" ) @@ -174,3 +180,144 @@ func TestDecodePrivateKeyBytes(t *testing.T) { t.Run(test.name, testFn(test)) } } + +type testBundle struct { + cert *x509.Certificate + pem []byte + pk crypto.PrivateKey +} + +func mustCreateBundle(t *testing.T, issuer *testBundle, name string) *testBundle { + pk, err := GenerateRSAPrivateKey(2048) + if err != nil { + t.Fatal(err) + } + + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatal(err) + } + + template := &x509.Certificate{ + Version: 3, + BasicConstraintsValid: true, + SerialNumber: serialNumber, + PublicKeyAlgorithm: x509.RSA, + PublicKey: pk.Public(), + IsCA: true, + Subject: pkix.Name{ + CommonName: name, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + } + + var ( + issuerKey crypto.PrivateKey + issuerCert *x509.Certificate + ) + + if issuer == nil { + // Selfsigned (no issuer) + issuerKey = pk + issuerCert = template + } else { + issuerKey = issuer.pk + issuerCert = issuer.cert + } + + certpem, cert, err := SignCertificate(template, issuerCert, pk.Public(), issuerKey) + if err != nil { + t.Fatal(err) + } + + return &testBundle{pem: certpem, cert: cert, pk: pk} +} + +func TestParseCertificateChain(t *testing.T) { + root := mustCreateBundle(t, nil, "root") + int1 := mustCreateBundle(t, root, "int-1") + int2 := mustCreateBundle(t, int1, "int-2") + leaf := mustCreateBundle(t, int2, "leaf") + random := mustCreateBundle(t, nil, "random") + + joinPEM := func(first []byte, rest ...[]byte) []byte { + for _, b := range rest { + first = append(first, b...) + } + return first + } + + tests := map[string]struct { + inputBundle []byte + expPEMBundle PEMBundle + expErr bool + }{ + "if single certificate passed, return single certificate": { + inputBundle: root.pem, + expPEMBundle: PEMBundle{ChainPEM: root.pem}, + expErr: false, + }, + "if two certificate chain passed in order, should return single ca and certificate": { + inputBundle: joinPEM(int1.pem, root.pem), + expPEMBundle: PEMBundle{ChainPEM: int1.pem, CAPEM: root.pem}, + expErr: false, + }, + "if two certificate chain passed out of order, should return single ca and certificate": { + inputBundle: joinPEM(root.pem, int1.pem), + expPEMBundle: PEMBundle{ChainPEM: int1.pem, CAPEM: root.pem}, + expErr: false, + }, + "if 3 certificate chain passed out of order, should return single ca and chain in order": { + inputBundle: joinPEM(root.pem, int2.pem, int1.pem), + expPEMBundle: PEMBundle{ChainPEM: joinPEM(int2.pem, int1.pem), CAPEM: root.pem}, + expErr: false, + }, + "empty entries should be ignored, and return ca and certificate": { + inputBundle: joinPEM(root.pem, int2.pem, []byte("\n#foo\n \n"), int1.pem), + expPEMBundle: PEMBundle{ChainPEM: joinPEM(int2.pem, int1.pem), CAPEM: root.pem}, + expErr: false, + }, + "if 4 certificate chain passed in order, should return single ca and chain in order": { + inputBundle: joinPEM(leaf.pem, int1.pem, int2.pem, root.pem), + expPEMBundle: PEMBundle{ChainPEM: joinPEM(leaf.pem, int2.pem, int1.pem), CAPEM: root.pem}, + expErr: false, + }, + "if 4 certificate chain passed out of order, should return single ca and chain in order": { + inputBundle: joinPEM(root.pem, int1.pem, leaf.pem, int2.pem), + expPEMBundle: PEMBundle{ChainPEM: joinPEM(leaf.pem, int2.pem, int1.pem), CAPEM: root.pem}, + expErr: false, + }, + "if 3 certificate chain but has break in the chain, should return error": { + inputBundle: joinPEM(root.pem, int1.pem, leaf.pem), + expPEMBundle: PEMBundle{}, + expErr: true, + }, + "if 4 certificate chain but also random certificate, should return error": { + inputBundle: joinPEM(root.pem, int1.pem, leaf.pem, int2.pem, random.pem), + expPEMBundle: PEMBundle{}, + expErr: true, + }, + "if 6 certificate chain but some are duplicates, duplicates should be removed and return single ca with chain": { + inputBundle: joinPEM(int2.pem, int1.pem, root.pem, leaf.pem, int1.pem, root.pem), + expPEMBundle: PEMBundle{ChainPEM: joinPEM(leaf.pem, int2.pem, int1.pem), CAPEM: root.pem}, + expErr: false, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + bundle, err := ParseCertificateChainPEM(test.inputBundle) + if (err != nil) != test.expErr { + t.Errorf("unexpected error, exp=%t got=%v", + test.expErr, err) + } + + if !reflect.DeepEqual(bundle, test.expPEMBundle) { + t.Errorf("unexpected pem bundle, exp=%+s got=%+s", + test.expPEMBundle, bundle) + } + }) + } +}