diff --git a/pkg/util/pki/generate.go b/pkg/util/pki/generate.go index 60d0a1e3e..e7f8fcfdf 100644 --- a/pkg/util/pki/generate.go +++ b/pkg/util/pki/generate.go @@ -174,70 +174,35 @@ func PublicKeyForPrivateKey(pk crypto.PrivateKey) (crypto.PublicKey, error) { } } -// PublicKeyMatchesCertificate can be used to verify the given public key -// is the correct counter-part to the given x509 Certificate. -// It will return false and no error if the public key is *not* valid for the -// given Certificate. -// It will return true if the public key *is* valid for the given Certificate. -// It will return an error if either of the passed parameters are of an -// unrecognised type (i.e. non RSA/ECDSA) +// PublicKeyMatchesCertificate checks whether the given public key matches the +// public key in the given x509.Certificate. +// Returns false and no error if the public key is *not* the same as the certificate's key +// Returns true and no error if the public key *is* the same as the certificate's key +// Returns an error if the certificate's key type cannot be determined (i.e. non RSA/ECDSA keys) func PublicKeyMatchesCertificate(check crypto.PublicKey, crt *x509.Certificate) (bool, error) { - switch pub := crt.PublicKey.(type) { - case *rsa.PublicKey: - rsaCheck, ok := check.(*rsa.PublicKey) - if !ok { - return false, nil - } - if pub.N.Cmp(rsaCheck.N) != 0 { - return false, nil - } - return true, nil - case *ecdsa.PublicKey: - ecdsaCheck, ok := check.(*ecdsa.PublicKey) - if !ok { - return false, nil - } - if pub.X.Cmp(ecdsaCheck.X) != 0 || pub.Y.Cmp(ecdsaCheck.Y) != 0 { - return false, nil - } - return true, nil - default: - return false, fmt.Errorf("unrecognised Certificate public key type") - } + return PublicKeysEqual(crt.PublicKey, check) } -// PublicKeyMatchesCSR can be used to verify the given public key is the correct -// counter-part to the given x509 CertificateRequest. -// It will return false and no error if the public key is *not* valid for the -// given CertificateRequest. -// It will return true if the public key *is* valid for the given CertificateRequest. -// It will return an error if either of the passed parameters are of an -// unrecognised type (i.e. non RSA/ECDSA) +// PublicKeyMatchesCSR can be used to verify the given public key matches the +// public key in the given x509.CertificateRequest. +// Returns false and no error if the given public key is *not* the same as the CSR's key +// Returns true and no error if the given public key *is* the same as the CSR's key +// Returns an error if the CSR's key type cannot be determined (i.e. non RSA/ECDSA keys) func PublicKeyMatchesCSR(check crypto.PublicKey, csr *x509.CertificateRequest) (bool, error) { - return PublicKeysEqual(check, csr.PublicKey) + return PublicKeysEqual(csr.PublicKey, check) } +// PublicKeysEqual compares two given public keys for equality. +// The definition of "equality" depends on the type of the public keys. +// Returns true if the keys are the same, false if they differ or an error if +// the key type of `a` cannot be determined. func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) { switch pub := a.(type) { case *rsa.PublicKey: - rsaCheck, ok := b.(*rsa.PublicKey) - if !ok { - return false, nil - } - if pub.N.Cmp(rsaCheck.N) != 0 { - return false, nil - } - return true, nil + return pub.Equal(b), nil case *ecdsa.PublicKey: - ecdsaCheck, ok := b.(*ecdsa.PublicKey) - if !ok { - return false, nil - } - if pub.X.Cmp(ecdsaCheck.X) != 0 || pub.Y.Cmp(ecdsaCheck.Y) != 0 { - return false, nil - } - return true, nil + return pub.Equal(b), nil default: - return false, fmt.Errorf("unrecognised public key type") + return false, fmt.Errorf("unrecognised public key type: %T", a) } } diff --git a/pkg/util/pki/generate_test.go b/pkg/util/pki/generate_test.go index 30640f60c..d1cb398f5 100644 --- a/pkg/util/pki/generate_test.go +++ b/pkg/util/pki/generate_test.go @@ -454,3 +454,109 @@ O7WnDn8nuLFdW+NzzbIrTw== t.Run(test.name, testFn(test)) } } + +func TestPublicKeysEqualECDSA(t *testing.T) { + key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("couldn't generate P256 key: %v", err) + } + + // note the different curve type + key2, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + t.Fatalf("couldn't generate P521 key: %v", err) + } + + pub1 := key1.Public().(*ecdsa.PublicKey) + + // (pub1.X, pub1.Y) isn't likely to be on the curve for key2, so pub2 will be + // invalid after changing its X and Y below; still, pub2 is useful for + // the test + + // this is not dissimilar to the standard library's test: + // https://github.com/golang/go/blob/14a18b7d2538232c6cd6937297c421d5f6b7d92f/src/crypto/ecdsa/equal_test.go#L55-L64 + pub2 := key2.Public().(*ecdsa.PublicKey) + pub2.X = pub1.X + pub2.Y = pub1.Y + + if pub1.Equal(pub2) { + t.Fatalf("invalid test: got a match from curves which should differ:\npub1: %#v\npub2: %#v\n", pub1, pub2) + } + + match, err := PublicKeysEqual(pub1, pub2) + if err != nil { + t.Fatalf("unexpected error from PublicKeysEqual: %v", err) + } + + if match { + t.Errorf("got an incorrect match from different curves:\npub1 type: %#v\npub2 type: %#v\n", pub1.Params().Name, pub2.Params().Name) + } +} + +const hardcodedTestKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAtaX7TB+WQ4yhTRLw6G0V8oLRaRzjJ1UiAh1Uw+K8SpgMehfm +WS6//y0iBwCfQWTC4Fw1sU99XA8yOIN+fMMdNcmPxvP7JKlUQRrM3RpXXD5eo+MZ +fJmc344Pn5f/aMoFvDq745YLTP3C5PJj1qcljch00FPORCCCFtdFkynzKoclZruW +cJpbFgt9mE/Qk2Xed8FZ+AESZmxAYVEhCv3JETpQfO3cW15+Hxug4eMQ6daAyYjT ++52QbENIXliKYCS+LrguQJsMNMveajoWcGMHbQBs+I2umlh0UDYRAZ3PbgO9GDfZ +tUoasBDM0SDjvtpiL+8UnbDlrwfZYgwGo/YixQIDAQABAoIBAQCIc0yYXEn2KBeq +3AWXswn/iAFiok6IZ00KpZndI98pcZo9xOJGL/YN64taEz+OUfCJtPqoXPvgQZIK +HczQT4kLtIOKghAv8/rUhRtLI9Rn+HoDRj8I+CN9UyutSPKVdtxkDwLA7R9EEINs +lCAnSJvPK7uEGtAhIQJXwhIDgEmnsWSKq3OTNRbKe4kF7bOAVsEj9KZjmHcWuhDV +LJ0x1+uWo18UFztHmQL/Vp0VJKYBo2tAql3LjHtGFI+uZ38X2HsAgIQSjVvXZyR1 +FvZx0XymoF8zYJzV2yfzgF9Toot4SWlKsUuX3w2FYj8hnQbn02x0m7sZmhNY25Fd +ljZCWOFpAoGBAOloRL0kKAA71lR6zwV7M+Xxozzk3u2x+GdTRCeBtnzjMQGtLs5/ +KsOROPNj2/wv+rH8FhAFiFKcwr4RrouVWxqCd6YwtAiRe92NfkuXPIDTq2G6K9ge +i5Z5yMImjeG1P4GvaQg9f1YF8oNO/EEtWihyN+VcLQHJFy714cMhurZDAoGBAMc7 +JjRmjtybZj3VzPTp0c2emce1fGHMtcmFX+dauLmGNDXCo/cJs4XYKlp6vhuBY4PR +IcbqVFBshk2cCC6IRTsAIcPLi4rwqe8uRhtHbyXN1lmcNqJq/l9Q8xoQf7l/ik4r +ttrSb7/I2hyEm2xJFTTXqbpx3AQqQQbPwl3sFuZXAoGBAKpv6UH0VQFWsHuf8ewe +uxb+DCU7O0521t0cgHgY0BkCDZcbz0Iaui90rBGOqeTNZFLzsWihoZoxvkLsxnhG +5+/DtXs1tUFMexada8vm89deuZbzS3DVXTjUVTTw0kou/+DDJf9OaN14Gk6oLqup +YlyGiyqA1JypKrSv99t1ldHhAoGAQBFKWOF+IX0rpMjjLwMd/8R36Vv4Uq706ogk +bg6jhq2cjok4FxIck/cOr6f3CHtUWChhd0kVsgMkMULy8pvJv45sTT1gc16vFwZH +bzBKktqdipWMkDBd+qLaelBB8pIMFNVD6Rxw6Tiawz71iB38XtDXeOhyezhnTtxy +wadROeMCgYEAlsfw3Gk5zftMFsPsfvFREvq3em+UC0jP5FLcz+LzTk1mEn+1SXtB +lFP7bcMXkRBh4tlk0gDLHvnwIomA+/dRnEIGBPl5nvZNF1HybUWxXTa3dW9Jw9V/ +3J9xMYH/v9uMSt0j5xhPcTrI6HYtrT5lZMZNOI5vbVo3D6KYLWtfgWA= +-----END RSA PRIVATE KEY----- +` + +func TestPublicKeysEqualRSA(t *testing.T) { + // parse a hardcoded key rather than generating since generating RSA keys + // is absurdly slow: + // BenchmarkGen-8 12 101415795 ns/op + // BenchmarkParse-8 48930 24361 ns/op + rawKey, err := DecodePrivateKeyBytes([]byte(hardcodedTestKey)) + if err != nil { + t.Fatalf("couldn't parse RSA test key: %v", err) + } + + key1 := rawKey.(*rsa.PrivateKey) + + pub1 := key1.Public().(*rsa.PublicKey) + + // changing E like this might mean the public key is invalid, but + // it should still be fine for testing our comparison function + pub2 := &rsa.PublicKey{} + *pub2 = *pub1 + + // 3 is valid because the exponent in hardcodedTestKey is 65535 + // if the test key changes, this could have to change. + // note that there are relatively few exponents actually used in the real world + // and as such this shouldn't just be a random value + pub2.E = 3 + + if pub1.Equal(pub2) { + t.Fatalf("invalid test: got a match from keys which should differ:\npub1: %#v\npub2: %#v\n", pub1, pub2) + } + + match, err := PublicKeysEqual(pub1, pub2) + if err != nil { + t.Fatalf("unexpected error from PublicKeysEqual: %v", err) + } + + if match { + t.Errorf("got an incorrect match from different RSA keys:\npub1: %#v\npub2: %#v\n", pub1, pub2) + } +}