diff --git a/sm4/sm4.go b/sm4/sm4.go index 0e301deb..04c78388 100644 --- a/sm4/sm4.go +++ b/sm4/sm4.go @@ -27,7 +27,7 @@ import ( const BlockSize = 16 -var IV = make([]byte, BlockSize) +var errIVLen = errors.New("SM4: invalid iv size") type SM4Key []byte @@ -291,26 +291,21 @@ func pkcs7UnPadding(src []byte) ([]byte, error) { return src[:(length - unpadding)], nil } -func SetIV(iv []byte) error { - if len(iv) != BlockSize { - return errors.New("SM4: invalid iv size") - } - IV = iv - return nil -} -func Sm4Cbc(key []byte, in []byte, mode bool) (out []byte, err error) { +func Sm4Cbc(key []byte, iv []byte, in []byte, mode bool) (out []byte, err error) { if len(key) != BlockSize { return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key))) } + if len(iv) != BlockSize { + return nil, errIVLen + } + var inData []byte if mode { inData = pkcs7Padding(in) } else { inData = in } - iv := make([]byte, BlockSize) - copy(iv, IV) out = make([]byte, len(inData)) c, err := NewCipher(key) if err != nil { @@ -376,10 +371,14 @@ func Sm4Ecb(key []byte, in []byte, mode bool) (out []byte, err error) { //密码反馈模式(Cipher FeedBack (CFB)) //https://blog.csdn.net/zy_strive_2012/article/details/102520356 //https://blog.csdn.net/sinat_23338865/article/details/72869841 -func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) { +func Sm4CFB(key []byte, iv []byte, in []byte, mode bool) (out []byte, err error) { if len(key) != BlockSize { return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key))) } + if len(iv) != BlockSize { + return nil, errIVLen + } + var inData []byte if mode { inData = pkcs7Padding(in) @@ -399,7 +398,7 @@ func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) { if mode { //加密 for i := 0; i < len(inData)/16; i++ { if i == 0 { - c.Encrypt(K, IV) + c.Encrypt(K, iv) cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) copy(out[i*16:i*16+16], cipherBlock) //copy(cipherBlock,out_tmp) @@ -415,7 +414,7 @@ func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) { var i int = 0 for ; i < len(inData)/16; i++ { if i == 0 { - c.Encrypt(K, IV) //这里是加密,而不是调用解密方法Decrypt + c.Encrypt(K, iv) //这里是加密,而不是调用解密方法Decrypt plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) //获取明文分组 copy(out[i*16:i*16+16], plainBlock) continue @@ -435,10 +434,14 @@ func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) { //输出反馈模式(Output feedback, OFB) //https://blog.csdn.net/chengqiuming/article/details/82390910 //https://blog.csdn.net/sinat_23338865/article/details/72869841 -func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) { +func Sm4OFB(key []byte, iv []byte, in []byte, mode bool) (out []byte, err error) { if len(key) != BlockSize { return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key))) } + if len(iv) != BlockSize { + return nil, errors.New("SM4: invalid iv size") + } + var inData []byte if mode { inData = pkcs7Padding(in) @@ -459,7 +462,7 @@ func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) { if mode { //加密 for i := 0; i < len(inData)/16; i++ { if i == 0 { - c.Encrypt(K, IV) + c.Encrypt(K, iv) cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) copy(out[i*16:i*16+16], cipherBlock) copy(shiftIV, K[:BlockSize]) @@ -474,14 +477,14 @@ func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) { } else { //解密 for i := 0; i < len(inData)/16; i++ { if i == 0 { - c.Encrypt(K, IV) //这里是加密,而不是调用解密方法Decrypt - plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) //获取明文分组 + c.Encrypt(K, iv) // 这里是加密,而不是调用解密方法Decrypt + plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // 获取明文分组 copy(out[i*16:i*16+16], plainBlock) copy(shiftIV, K[:BlockSize]) continue } c.Encrypt(K, shiftIV) - plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) //获取明文分组 + plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // 获取明文分组 copy(out[i*16:i*16+16], plainBlock) copy(shiftIV, K[:BlockSize]) } diff --git a/sm4/sm4_test.go b/sm4/sm4_test.go index 6115ba14..cb9bd0c3 100644 --- a/sm4/sm4_test.go +++ b/sm4/sm4_test.go @@ -43,8 +43,6 @@ func TestSM4(t *testing.T) { } fmt.Printf("ecbMsg = %x\n", ecbMsg) iv := []byte("0000000000000000") - err = SetIV(iv) - fmt.Printf("err = %v\n", err) ecbDec, err := Sm4Ecb(key, ecbMsg, false) if err != nil { t.Errorf("sm4 dec error:%s", err) @@ -54,12 +52,12 @@ func TestSM4(t *testing.T) { if !testCompare(data, ecbDec) { t.Errorf("sm4 self enc and dec failed") } - cbcMsg, err := Sm4Cbc(key, data, true) + cbcMsg, err := Sm4Cbc(key, iv, data, true) if err != nil { t.Errorf("sm4 enc error:%s", err) } fmt.Printf("cbcMsg = %x\n", cbcMsg) - cbcDec, err := Sm4Cbc(key, cbcMsg, false) + cbcDec, err := Sm4Cbc(key, iv, cbcMsg, false) if err != nil { t.Errorf("sm4 dec error:%s", err) return @@ -69,26 +67,26 @@ func TestSM4(t *testing.T) { t.Errorf("sm4 self enc and dec failed") } - cbcMsg, err = Sm4CFB(key, data, true) + cbcMsg, err = Sm4CFB(key, iv, data, true) if err != nil { t.Errorf("sm4 enc error:%s", err) } fmt.Printf("cbcCFB = %x\n", cbcMsg) - cbcCfb, err := Sm4CFB(key, cbcMsg, false) + cbcCfb, err := Sm4CFB(key, iv, cbcMsg, false) if err != nil { t.Errorf("sm4 dec error:%s", err) return } fmt.Printf("cbcCFB = %x\n", cbcCfb) - cbcMsg, err = Sm4OFB(key, data, true) + cbcMsg, err = Sm4OFB(key, iv, data, true) if err != nil { t.Errorf("sm4 enc error:%s", err) } fmt.Printf("cbcOFB = %x\n", cbcMsg) - cbcOfc, err := Sm4OFB(key, cbcMsg, false) + cbcOfc, err := Sm4OFB(key, iv, cbcMsg, false) if err != nil { t.Errorf("sm4 dec error:%s", err) return