Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sm4): using iv parameter instead of SetIV to avoid concurrency issue #129

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions sm4/sm4.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (

const BlockSize = 16

var IV = make([]byte, BlockSize)
var errIVLen = errors.New("SM4: invalid iv size")

type SM4Key []byte

Expand Down Expand Up @@ -291,26 +291,21 @@ func pkcs7UnPadding(src []byte) ([]byte, error) {

return src[:(length - unpadding)], nil
}
func SetIV(iv []byte) error {
if len(iv) != BlockSize {
return errors.New("SM4: invalid iv size")
}
IV = iv
return nil
}

func Sm4Cbc(key []byte, in []byte, mode bool) (out []byte, err error) {
func Sm4Cbc(key []byte, iv []byte, in []byte, mode bool) (out []byte, err error) {
if len(key) != BlockSize {
return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
}
if len(iv) != BlockSize {
return nil, errIVLen
}

var inData []byte
if mode {
inData = pkcs7Padding(in)
} else {
inData = in
}
iv := make([]byte, BlockSize)
copy(iv, IV)
out = make([]byte, len(inData))
c, err := NewCipher(key)
if err != nil {
Expand Down Expand Up @@ -376,10 +371,14 @@ func Sm4Ecb(key []byte, in []byte, mode bool) (out []byte, err error) {
//密码反馈模式(Cipher FeedBack (CFB))
//https://blog.csdn.net/zy_strive_2012/article/details/102520356
//https://blog.csdn.net/sinat_23338865/article/details/72869841
func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) {
func Sm4CFB(key []byte, iv []byte, in []byte, mode bool) (out []byte, err error) {
if len(key) != BlockSize {
return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
}
if len(iv) != BlockSize {
return nil, errIVLen
}

var inData []byte
if mode {
inData = pkcs7Padding(in)
Expand All @@ -399,7 +398,7 @@ func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) {
if mode { //加密
for i := 0; i < len(inData)/16; i++ {
if i == 0 {
c.Encrypt(K, IV)
c.Encrypt(K, iv)
cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16])
copy(out[i*16:i*16+16], cipherBlock)
//copy(cipherBlock,out_tmp)
Expand All @@ -415,7 +414,7 @@ func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) {
var i int = 0
for ; i < len(inData)/16; i++ {
if i == 0 {
c.Encrypt(K, IV) //这里是加密,而不是调用解密方法Decrypt
c.Encrypt(K, iv) //这里是加密,而不是调用解密方法Decrypt
plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) //获取明文分组
copy(out[i*16:i*16+16], plainBlock)
continue
Expand All @@ -435,10 +434,14 @@ func Sm4CFB(key []byte, in []byte, mode bool) (out []byte, err error) {
//输出反馈模式(Output feedback, OFB)
//https://blog.csdn.net/chengqiuming/article/details/82390910
//https://blog.csdn.net/sinat_23338865/article/details/72869841
func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) {
func Sm4OFB(key []byte, iv []byte, in []byte, mode bool) (out []byte, err error) {
if len(key) != BlockSize {
return nil, errors.New("SM4: invalid key size " + strconv.Itoa(len(key)))
}
if len(iv) != BlockSize {
return nil, errors.New("SM4: invalid iv size")
}

var inData []byte
if mode {
inData = pkcs7Padding(in)
Expand All @@ -459,7 +462,7 @@ func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) {
if mode { //加密
for i := 0; i < len(inData)/16; i++ {
if i == 0 {
c.Encrypt(K, IV)
c.Encrypt(K, iv)
cipherBlock = xor(K[:BlockSize], inData[i*16:i*16+16])
copy(out[i*16:i*16+16], cipherBlock)
copy(shiftIV, K[:BlockSize])
Expand All @@ -474,14 +477,14 @@ func Sm4OFB(key []byte, in []byte, mode bool) (out []byte, err error) {
} else { //解密
for i := 0; i < len(inData)/16; i++ {
if i == 0 {
c.Encrypt(K, IV) //这里是加密,而不是调用解密方法Decrypt
plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) //获取明文分组
c.Encrypt(K, iv) // 这里是加密,而不是调用解密方法Decrypt
plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // 获取明文分组
copy(out[i*16:i*16+16], plainBlock)
copy(shiftIV, K[:BlockSize])
continue
}
c.Encrypt(K, shiftIV)
plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) //获取明文分组
plainBlock = xor(K[:BlockSize], inData[i*16:i*16+16]) // 获取明文分组
copy(out[i*16:i*16+16], plainBlock)
copy(shiftIV, K[:BlockSize])
}
Expand Down
14 changes: 6 additions & 8 deletions sm4/sm4_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ func TestSM4(t *testing.T) {
}
fmt.Printf("ecbMsg = %x\n", ecbMsg)
iv := []byte("0000000000000000")
err = SetIV(iv)
fmt.Printf("err = %v\n", err)
ecbDec, err := Sm4Ecb(key, ecbMsg, false)
if err != nil {
t.Errorf("sm4 dec error:%s", err)
Expand All @@ -54,12 +52,12 @@ func TestSM4(t *testing.T) {
if !testCompare(data, ecbDec) {
t.Errorf("sm4 self enc and dec failed")
}
cbcMsg, err := Sm4Cbc(key, data, true)
cbcMsg, err := Sm4Cbc(key, iv, data, true)
if err != nil {
t.Errorf("sm4 enc error:%s", err)
}
fmt.Printf("cbcMsg = %x\n", cbcMsg)
cbcDec, err := Sm4Cbc(key, cbcMsg, false)
cbcDec, err := Sm4Cbc(key, iv, cbcMsg, false)
if err != nil {
t.Errorf("sm4 dec error:%s", err)
return
Expand All @@ -69,26 +67,26 @@ func TestSM4(t *testing.T) {
t.Errorf("sm4 self enc and dec failed")
}

cbcMsg, err = Sm4CFB(key, data, true)
cbcMsg, err = Sm4CFB(key, iv, data, true)
if err != nil {
t.Errorf("sm4 enc error:%s", err)
}
fmt.Printf("cbcCFB = %x\n", cbcMsg)

cbcCfb, err := Sm4CFB(key, cbcMsg, false)
cbcCfb, err := Sm4CFB(key, iv, cbcMsg, false)
if err != nil {
t.Errorf("sm4 dec error:%s", err)
return
}
fmt.Printf("cbcCFB = %x\n", cbcCfb)

cbcMsg, err = Sm4OFB(key, data, true)
cbcMsg, err = Sm4OFB(key, iv, data, true)
if err != nil {
t.Errorf("sm4 enc error:%s", err)
}
fmt.Printf("cbcOFB = %x\n", cbcMsg)

cbcOfc, err := Sm4OFB(key, cbcMsg, false)
cbcOfc, err := Sm4OFB(key, iv, cbcMsg, false)
if err != nil {
t.Errorf("sm4 dec error:%s", err)
return
Expand Down