From fb9cefa7a3cce3572f2a3a83e37d7381a8dd86b9 Mon Sep 17 00:00:00 2001 From: Alver Lyu Date: Thu, 1 Aug 2019 12:08:34 +0800 Subject: [PATCH] add checks in sm2 encryption --- sm2/encryption.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sm2/encryption.go b/sm2/encryption.go index 7bc83af..a3ec65f 100644 --- a/sm2/encryption.go +++ b/sm2/encryption.go @@ -137,6 +137,9 @@ func Encrypt(pub *ecdsa.PublicKey, data []byte) ([]byte, error) { copy(encryptData[32-len(x1.Bytes()):], x1.Bytes()) copy(encryptData[64-len(y1.Bytes()):], y1.Bytes()) + if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { + return nil, errors.New("invalid public key") + } x2, y2 = c.ScalarMult(pub.X, pub.Y, k.Bytes()) x2y2 := make([]byte, 64) copy(x2y2[32-len(x2.Bytes()):], x2.Bytes()) @@ -175,6 +178,12 @@ func Decrypt(priv *ecdsa.PrivateKey, encryptData []byte) ([]byte, error) { x1 := new(big.Int).SetBytes(encryptData[:32]) y1 := new(big.Int).SetBytes(encryptData[32:64]) + if x1.Sign() == 0 && y1.Sign() == 0 { + return nil, errors.New("C1 is infinity") + } + if !c.IsOnCurve(x1, y1) { + return nil, errors.New("C1 is not on curve") + } x2, y2 := c.ScalarMult(x1, y1, priv.D.Bytes()) c2 := make([]byte, 64)