add tests for the improvements made in #6561

Signed-off-by: Tim Ramlot <42113979+inteon@users.noreply.github.com>
This commit is contained in:
Tim Ramlot 2024-01-03 17:25:15 +01:00
parent e9a4793ba4
commit 9547fbdf94
No known key found for this signature in database
GPG Key ID: 47428728E0C2878D

View File

@ -21,6 +21,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"fmt"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -98,10 +99,28 @@ func TestParseSingleCertificateChain(t *testing.T) {
leafInterCN := mustCreateBundle(t, intA2, intA2.cert.Subject.CommonName) leafInterCN := mustCreateBundle(t, intA2, intA2.cert.Subject.CommonName)
random := mustCreateBundle(t, nil, "random") random := mustCreateBundle(t, nil, "random")
var thousandCertBundle PEMBundle
{
root := mustCreateBundle(t, nil, "root")
thousandCertBundle.CAPEM = root.pem
cert := root
var pems [][]byte
for i := 0; i < 999; i++ {
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
pems = append(pems, cert.pem)
}
for i := len(pems) - 1; i >= 0; i-- {
thousandCertBundle.ChainPEM = joinPEM(thousandCertBundle.ChainPEM, pems[i])
}
}
tests := map[string]struct { tests := map[string]struct {
inputBundle []byte inputBundle []byte
expPEMBundle PEMBundle expPEMBundle PEMBundle
expErr bool expErr bool
expErrString string
}{ }{
"if two certificate chain passed in order, should return single ca and certificate": { "if two certificate chain passed in order, should return single ca and certificate": {
inputBundle: joinPEM(intA1.pem, root.pem), inputBundle: joinPEM(intA1.pem, root.pem),
@ -148,11 +167,13 @@ func TestParseSingleCertificateChain(t *testing.T) {
inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem), inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem),
expPEMBundle: PEMBundle{}, expPEMBundle: PEMBundle{},
expErr: true, expErr: true,
expErrString: "certificate chain is malformed or broken",
}, },
"if 4 certificate chain but also random certificate, should return error": { "if 4 certificate chain but also random certificate, should return error": {
inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem, intA2.pem, random.pem), inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem, intA2.pem, random.pem),
expPEMBundle: PEMBundle{}, expPEMBundle: PEMBundle{},
expErr: true, expErr: true,
expErrString: "certificate chain is malformed or broken",
}, },
"if 6 certificate chain but some are duplicates, duplicates should be removed and return single ca with chain": { "if 6 certificate chain but some are duplicates, duplicates should be removed and return single ca with chain": {
inputBundle: joinPEM(intA2.pem, intA1.pem, root.pem, leaf.pem, intA1.pem, root.pem), inputBundle: joinPEM(intA2.pem, intA1.pem, root.pem, leaf.pem, intA1.pem, root.pem),
@ -168,6 +189,7 @@ func TestParseSingleCertificateChain(t *testing.T) {
inputBundle: joinPEM(root.pem, intA1.pem, intA2.pem, intB1.pem, intB2.pem), inputBundle: joinPEM(root.pem, intA1.pem, intA2.pem, intB1.pem, intB2.pem),
expPEMBundle: PEMBundle{}, expPEMBundle: PEMBundle{},
expErr: true, expErr: true,
expErrString: "certificate chain is malformed or broken",
}, },
"if certificate chain does not have a root ca, should append all intermediates to ChainPEM and use the root-most cert as CAPEM": { "if certificate chain does not have a root ca, should append all intermediates to ChainPEM and use the root-most cert as CAPEM": {
inputBundle: joinPEM(intA1.pem, intA2.pem, leaf.pem), inputBundle: joinPEM(intA1.pem, intA2.pem, leaf.pem),
@ -189,16 +211,65 @@ func TestParseSingleCertificateChain(t *testing.T) {
expPEMBundle: PEMBundle{ChainPEM: joinPEM(root.pem), CAPEM: root.pem}, expPEMBundle: PEMBundle{ChainPEM: joinPEM(root.pem), CAPEM: root.pem},
expErr: false, expErr: false,
}, },
"if long chain is passed (<= 1000 certs), a result should be returned quickly": {
inputBundle: joinPEM(thousandCertBundle.ChainPEM, thousandCertBundle.CAPEM),
expPEMBundle: thousandCertBundle,
expErr: false,
},
"if very long chain is passed (> 1000 certs), should error without DoS (1)": {
inputBundle: func() []byte {
root := mustCreateBundle(t, nil, "root")
cert := root
var chain []byte
for i := 0; i < 1001; i++ {
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
chain = joinPEM(chain, cert.pem)
}
return chain
}(),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is too long, must be less than 1000 certificates",
},
"if very long chain is passed (> 1000 certs), should error without DoS (2)": {
inputBundle: func() []byte {
root := mustCreateBundle(t, nil, "root")
cert := root
var chain []byte
for i := 0; i < 10000; i++ {
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
chain = joinPEM(chain, cert.pem)
}
return chain
}(),
expPEMBundle: PEMBundle{},
expErr: true,
expErrString: "certificate chain is too long, must be less than 1000 certificates",
},
} }
for name, test := range tests { for name, test := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
startTime := time.Now()
bundle, err := ParseSingleCertificateChainPEM(test.inputBundle) bundle, err := ParseSingleCertificateChainPEM(test.inputBundle)
if (err != nil) != test.expErr { if (err != nil) != test.expErr {
t.Errorf("unexpected error, exp=%t got=%v", t.Errorf("unexpected error, exp=%t got=%v",
test.expErr, err) test.expErr, err)
} }
if time.Since(startTime) > time.Second {
t.Errorf("ParseSingleCertificateChainPEM took too long to complete, input could cause DoS")
}
if err != nil && err.Error() != test.expErrString {
t.Errorf("unexpected error string, exp=%s got=%s",
test.expErrString, err.Error())
}
if !reflect.DeepEqual(bundle, test.expPEMBundle) { if !reflect.DeepEqual(bundle, test.expPEMBundle) {
t.Errorf("unexpected pem bundle, exp=%+s got=%+s", t.Errorf("unexpected pem bundle, exp=%+s got=%+s",
test.expPEMBundle, bundle) test.expPEMBundle, bundle)