From 9547fbdf945436b9f64141a78407d66cd89eb1d3 Mon Sep 17 00:00:00 2001 From: Tim Ramlot <42113979+inteon@users.noreply.github.com> Date: Wed, 3 Jan 2024 17:25:15 +0100 Subject: [PATCH] add tests for the improvements made in #6561 Signed-off-by: Tim Ramlot <42113979+inteon@users.noreply.github.com> --- pkg/util/pki/parse_certificate_chain_test.go | 71 ++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/pkg/util/pki/parse_certificate_chain_test.go b/pkg/util/pki/parse_certificate_chain_test.go index b38f753c6..1807331c6 100644 --- a/pkg/util/pki/parse_certificate_chain_test.go +++ b/pkg/util/pki/parse_certificate_chain_test.go @@ -21,6 +21,7 @@ import ( "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "fmt" "reflect" "testing" "time" @@ -98,10 +99,28 @@ func TestParseSingleCertificateChain(t *testing.T) { leafInterCN := mustCreateBundle(t, intA2, intA2.cert.Subject.CommonName) random := mustCreateBundle(t, nil, "random") + var thousandCertBundle PEMBundle + { + root := mustCreateBundle(t, nil, "root") + thousandCertBundle.CAPEM = root.pem + + cert := root + var pems [][]byte + for i := 0; i < 999; i++ { + cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i)) + pems = append(pems, cert.pem) + } + + for i := len(pems) - 1; i >= 0; i-- { + thousandCertBundle.ChainPEM = joinPEM(thousandCertBundle.ChainPEM, pems[i]) + } + } + tests := map[string]struct { inputBundle []byte expPEMBundle PEMBundle expErr bool + expErrString string }{ "if two certificate chain passed in order, should return single ca and certificate": { inputBundle: joinPEM(intA1.pem, root.pem), @@ -148,11 +167,13 @@ func TestParseSingleCertificateChain(t *testing.T) { inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem), expPEMBundle: PEMBundle{}, expErr: true, + expErrString: "certificate chain is malformed or broken", }, "if 4 certificate chain but also random certificate, should return error": { inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem, intA2.pem, random.pem), expPEMBundle: PEMBundle{}, expErr: true, + expErrString: "certificate chain is malformed or broken", }, "if 6 certificate chain but some are duplicates, duplicates should be removed and return single ca with chain": { inputBundle: joinPEM(intA2.pem, intA1.pem, root.pem, leaf.pem, intA1.pem, root.pem), @@ -168,6 +189,7 @@ func TestParseSingleCertificateChain(t *testing.T) { inputBundle: joinPEM(root.pem, intA1.pem, intA2.pem, intB1.pem, intB2.pem), expPEMBundle: PEMBundle{}, expErr: true, + expErrString: "certificate chain is malformed or broken", }, "if certificate chain does not have a root ca, should append all intermediates to ChainPEM and use the root-most cert as CAPEM": { inputBundle: joinPEM(intA1.pem, intA2.pem, leaf.pem), @@ -189,16 +211,65 @@ func TestParseSingleCertificateChain(t *testing.T) { expPEMBundle: PEMBundle{ChainPEM: joinPEM(root.pem), CAPEM: root.pem}, expErr: false, }, + "if long chain is passed (<= 1000 certs), a result should be returned quickly": { + inputBundle: joinPEM(thousandCertBundle.ChainPEM, thousandCertBundle.CAPEM), + expPEMBundle: thousandCertBundle, + expErr: false, + }, + "if very long chain is passed (> 1000 certs), should error without DoS (1)": { + inputBundle: func() []byte { + root := mustCreateBundle(t, nil, "root") + + cert := root + var chain []byte + for i := 0; i < 1001; i++ { + cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i)) + chain = joinPEM(chain, cert.pem) + } + + return chain + }(), + expPEMBundle: PEMBundle{}, + expErr: true, + expErrString: "certificate chain is too long, must be less than 1000 certificates", + }, + "if very long chain is passed (> 1000 certs), should error without DoS (2)": { + inputBundle: func() []byte { + root := mustCreateBundle(t, nil, "root") + + cert := root + var chain []byte + for i := 0; i < 10000; i++ { + cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i)) + chain = joinPEM(chain, cert.pem) + } + + return chain + }(), + expPEMBundle: PEMBundle{}, + expErr: true, + expErrString: "certificate chain is too long, must be less than 1000 certificates", + }, } for name, test := range tests { t.Run(name, func(t *testing.T) { + startTime := time.Now() bundle, err := ParseSingleCertificateChainPEM(test.inputBundle) if (err != nil) != test.expErr { t.Errorf("unexpected error, exp=%t got=%v", test.expErr, err) } + if time.Since(startTime) > time.Second { + t.Errorf("ParseSingleCertificateChainPEM took too long to complete, input could cause DoS") + } + + if err != nil && err.Error() != test.expErrString { + t.Errorf("unexpected error string, exp=%s got=%s", + test.expErrString, err.Error()) + } + if !reflect.DeepEqual(bundle, test.expPEMBundle) { t.Errorf("unexpected pem bundle, exp=%+s got=%+s", test.expPEMBundle, bundle)