add tests for the improvements made in #6561

Signed-off-by: Tim Ramlot <42113979+inteon@users.noreply.github.com>
This commit is contained in:
Tim Ramlot 2024-01-03 17:25:15 +01:00
parent e9a4793ba4
commit 9547fbdf94
No known key found for this signature in database
GPG Key ID: 47428728E0C2878D

View File

@ -21,6 +21,7 @@ import (
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"reflect"
"testing"
"time"
@ -98,10 +99,28 @@ func TestParseSingleCertificateChain(t *testing.T) {
leafInterCN := mustCreateBundle(t, intA2, intA2.cert.Subject.CommonName)
random := mustCreateBundle(t, nil, "random")
var thousandCertBundle PEMBundle
{
root := mustCreateBundle(t, nil, "root")
thousandCertBundle.CAPEM = root.pem
cert := root
var pems [][]byte
for i := 0; i < 999; i++ {
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
pems = append(pems, cert.pem)
}
for i := len(pems) - 1; i >= 0; i-- {
thousandCertBundle.ChainPEM = joinPEM(thousandCertBundle.ChainPEM, pems[i])
}
}
tests := map[string]struct {
inputBundle []byte
expPEMBundle PEMBundle
expErr bool
expErrString string
}{
"if two certificate chain passed in order, should return single ca and certificate": {
inputBundle: joinPEM(intA1.pem, root.pem),
@ -148,11 +167,13 @@ func TestParseSingleCertificateChain(t *testing.T) {
inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is malformed or broken",
},
"if 4 certificate chain but also random certificate, should return error": {
inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem, intA2.pem, random.pem),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is malformed or broken",
},
"if 6 certificate chain but some are duplicates, duplicates should be removed and return single ca with chain": {
inputBundle: joinPEM(intA2.pem, intA1.pem, root.pem, leaf.pem, intA1.pem, root.pem),
@ -168,6 +189,7 @@ func TestParseSingleCertificateChain(t *testing.T) {
inputBundle: joinPEM(root.pem, intA1.pem, intA2.pem, intB1.pem, intB2.pem),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is malformed or broken",
},
"if certificate chain does not have a root ca, should append all intermediates to ChainPEM and use the root-most cert as CAPEM": {
inputBundle: joinPEM(intA1.pem, intA2.pem, leaf.pem),
@ -189,16 +211,65 @@ func TestParseSingleCertificateChain(t *testing.T) {
expPEMBundle: PEMBundle{ChainPEM: joinPEM(root.pem), CAPEM: root.pem},
expErr: false,
},
"if long chain is passed (<= 1000 certs), a result should be returned quickly": {
inputBundle: joinPEM(thousandCertBundle.ChainPEM, thousandCertBundle.CAPEM),
expPEMBundle: thousandCertBundle,
expErr: false,
},
"if very long chain is passed (> 1000 certs), should error without DoS (1)": {
inputBundle: func() []byte {
root := mustCreateBundle(t, nil, "root")
cert := root
var chain []byte
for i := 0; i < 1001; i++ {
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
chain = joinPEM(chain, cert.pem)
}
return chain
}(),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is too long, must be less than 1000 certificates",
},
"if very long chain is passed (> 1000 certs), should error without DoS (2)": {
inputBundle: func() []byte {
root := mustCreateBundle(t, nil, "root")
cert := root
var chain []byte
for i := 0; i < 10000; i++ {
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
chain = joinPEM(chain, cert.pem)
}
return chain
}(),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is too long, must be less than 1000 certificates",
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
startTime := time.Now()
bundle, err := ParseSingleCertificateChainPEM(test.inputBundle)
if (err != nil) != test.expErr {
t.Errorf("unexpected error, exp=%t got=%v",
test.expErr, err)
}
if time.Since(startTime) > time.Second {
t.Errorf("ParseSingleCertificateChainPEM took too long to complete, input could cause DoS")
}
if err != nil && err.Error() != test.expErrString {
t.Errorf("unexpected error string, exp=%s got=%s",
test.expErrString, err.Error())
}
if !reflect.DeepEqual(bundle, test.expPEMBundle) {
t.Errorf("unexpected pem bundle, exp=%+s got=%+s",
test.expPEMBundle, bundle)