add tests for the improvements made in #6561
Signed-off-by: Tim Ramlot <42113979+inteon@users.noreply.github.com>
This commit is contained in:
parent
e9a4793ba4
commit
9547fbdf94
@ -21,6 +21,7 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
@ -98,10 +99,28 @@ func TestParseSingleCertificateChain(t *testing.T) {
|
||||
leafInterCN := mustCreateBundle(t, intA2, intA2.cert.Subject.CommonName)
|
||||
random := mustCreateBundle(t, nil, "random")
|
||||
|
||||
var thousandCertBundle PEMBundle
|
||||
{
|
||||
root := mustCreateBundle(t, nil, "root")
|
||||
thousandCertBundle.CAPEM = root.pem
|
||||
|
||||
cert := root
|
||||
var pems [][]byte
|
||||
for i := 0; i < 999; i++ {
|
||||
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
|
||||
pems = append(pems, cert.pem)
|
||||
}
|
||||
|
||||
for i := len(pems) - 1; i >= 0; i-- {
|
||||
thousandCertBundle.ChainPEM = joinPEM(thousandCertBundle.ChainPEM, pems[i])
|
||||
}
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
inputBundle []byte
|
||||
expPEMBundle PEMBundle
|
||||
expErr bool
|
||||
expErrString string
|
||||
}{
|
||||
"if two certificate chain passed in order, should return single ca and certificate": {
|
||||
inputBundle: joinPEM(intA1.pem, root.pem),
|
||||
@ -148,11 +167,13 @@ func TestParseSingleCertificateChain(t *testing.T) {
|
||||
inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem),
|
||||
expPEMBundle: PEMBundle{},
|
||||
expErr: true,
|
||||
expErrString: "certificate chain is malformed or broken",
|
||||
},
|
||||
"if 4 certificate chain but also random certificate, should return error": {
|
||||
inputBundle: joinPEM(root.pem, intA1.pem, leaf.pem, intA2.pem, random.pem),
|
||||
expPEMBundle: PEMBundle{},
|
||||
expErr: true,
|
||||
expErrString: "certificate chain is malformed or broken",
|
||||
},
|
||||
"if 6 certificate chain but some are duplicates, duplicates should be removed and return single ca with chain": {
|
||||
inputBundle: joinPEM(intA2.pem, intA1.pem, root.pem, leaf.pem, intA1.pem, root.pem),
|
||||
@ -168,6 +189,7 @@ func TestParseSingleCertificateChain(t *testing.T) {
|
||||
inputBundle: joinPEM(root.pem, intA1.pem, intA2.pem, intB1.pem, intB2.pem),
|
||||
expPEMBundle: PEMBundle{},
|
||||
expErr: true,
|
||||
expErrString: "certificate chain is malformed or broken",
|
||||
},
|
||||
"if certificate chain does not have a root ca, should append all intermediates to ChainPEM and use the root-most cert as CAPEM": {
|
||||
inputBundle: joinPEM(intA1.pem, intA2.pem, leaf.pem),
|
||||
@ -189,16 +211,65 @@ func TestParseSingleCertificateChain(t *testing.T) {
|
||||
expPEMBundle: PEMBundle{ChainPEM: joinPEM(root.pem), CAPEM: root.pem},
|
||||
expErr: false,
|
||||
},
|
||||
"if long chain is passed (<= 1000 certs), a result should be returned quickly": {
|
||||
inputBundle: joinPEM(thousandCertBundle.ChainPEM, thousandCertBundle.CAPEM),
|
||||
expPEMBundle: thousandCertBundle,
|
||||
expErr: false,
|
||||
},
|
||||
"if very long chain is passed (> 1000 certs), should error without DoS (1)": {
|
||||
inputBundle: func() []byte {
|
||||
root := mustCreateBundle(t, nil, "root")
|
||||
|
||||
cert := root
|
||||
var chain []byte
|
||||
for i := 0; i < 1001; i++ {
|
||||
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
|
||||
chain = joinPEM(chain, cert.pem)
|
||||
}
|
||||
|
||||
return chain
|
||||
}(),
|
||||
expPEMBundle: PEMBundle{},
|
||||
expErr: true,
|
||||
expErrString: "certificate chain is too long, must be less than 1000 certificates",
|
||||
},
|
||||
"if very long chain is passed (> 1000 certs), should error without DoS (2)": {
|
||||
inputBundle: func() []byte {
|
||||
root := mustCreateBundle(t, nil, "root")
|
||||
|
||||
cert := root
|
||||
var chain []byte
|
||||
for i := 0; i < 10000; i++ {
|
||||
cert = mustCreateBundle(t, cert, fmt.Sprintf("int-%d", i))
|
||||
chain = joinPEM(chain, cert.pem)
|
||||
}
|
||||
|
||||
return chain
|
||||
}(),
|
||||
expPEMBundle: PEMBundle{},
|
||||
expErr: true,
|
||||
expErrString: "certificate chain is too long, must be less than 1000 certificates",
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
startTime := time.Now()
|
||||
bundle, err := ParseSingleCertificateChainPEM(test.inputBundle)
|
||||
if (err != nil) != test.expErr {
|
||||
t.Errorf("unexpected error, exp=%t got=%v",
|
||||
test.expErr, err)
|
||||
}
|
||||
|
||||
if time.Since(startTime) > time.Second {
|
||||
t.Errorf("ParseSingleCertificateChainPEM took too long to complete, input could cause DoS")
|
||||
}
|
||||
|
||||
if err != nil && err.Error() != test.expErrString {
|
||||
t.Errorf("unexpected error string, exp=%s got=%s",
|
||||
test.expErrString, err.Error())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(bundle, test.expPEMBundle) {
|
||||
t.Errorf("unexpected pem bundle, exp=%+s got=%+s",
|
||||
test.expPEMBundle, bundle)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user