diff --git a/smartcontract/service/native/auth/auth.go b/smartcontract/service/native/auth/auth.go index a3a55c70a2..6a79de0690 100644 --- a/smartcontract/service/native/auth/auth.go +++ b/smartcontract/service/native/auth/auth.go @@ -21,7 +21,6 @@ package auth import ( "bytes" "fmt" - "strings" "time" "github.com/ontio/ontology/account" @@ -191,13 +190,12 @@ func AssignFuncsToRole(native *native.NativeService) ([]byte, error) { if err != nil { return nil, fmt.Errorf("[assignFuncsToRole] getRoleFunc failed: %v", err) } - if funcs != nil { - funcNames := append(funcs.funcNames, param.FuncNames...) - funcs.funcNames = stringSliceUniq(funcNames) - } else { + if funcs == nil { funcs = new(roleFuncs) - funcs.funcNames = stringSliceUniq(param.FuncNames) } + + funcs.AppendFuncs(param.FuncNames) + err = putRoleFunc(native, param.ContractAddr, param.Role, funcs) if err != nil { return nil, fmt.Errorf("[assignFuncsToRole] putRoleFunc failed: %v", err) @@ -589,10 +587,8 @@ func verifyToken(native *native.NativeService, contractAddr common.Address, call if funcs == nil || token.expireTime < native.Time { continue } - for _, f := range funcs.funcNames { - if strings.Compare(fn, f) == 0 { - return true, nil - } + if funcs.ContainsFunc(fn) { + return true, nil } } } @@ -610,10 +606,8 @@ func verifyToken(native *native.NativeService, contractAddr common.Address, call if funcs == nil || s.expireTime < native.Time { continue } - for _, f := range funcs.funcNames { - if strings.Compare(fn, f) == 0 { - return true, nil - } + if funcs.ContainsFunc(fn) { + return true, nil } } } diff --git a/smartcontract/service/native/auth/state.go b/smartcontract/service/native/auth/state.go index a2644a65d9..473ad7465c 100644 --- a/smartcontract/service/native/auth/state.go +++ b/smartcontract/service/native/auth/state.go @@ -20,6 +20,7 @@ package auth import ( "io" + "strings" "github.com/ontio/ontology/common/serialization" ) @@ -31,10 +32,26 @@ type roleFuncs struct { funcNames []string } +func (this *roleFuncs) AppendFuncs(fns []string) { + funcNames := append(this.funcNames, fns...) + this.funcNames = StringsDedupAndSort(funcNames) +} + +func (this *roleFuncs) ContainsFunc(fn string) bool { + for _, f := range this.funcNames { + if strings.Compare(fn, f) == 0 { + return true + } + } + + return false +} + func (this *roleFuncs) Serialize(w io.Writer) error { if err := serialization.WriteUint32(w, uint32(len(this.funcNames))); err != nil { return err } + this.funcNames = StringsDedupAndSort(this.funcNames) for _, fn := range this.funcNames { if err := serialization.WriteString(w, fn); err != nil { return err @@ -49,14 +66,17 @@ func (this *roleFuncs) Deserialize(rd io.Reader) error { if err != nil { return err } - this.funcNames = make([]string, 0) + funcNames := make([]string, 0) for i := uint32(0); i < fnLen; i++ { fn, err := serialization.ReadString(rd) if err != nil { return err } - this.funcNames = append(this.funcNames, fn) + funcNames = append(funcNames, fn) } + + this.funcNames = StringsDedupAndSort(funcNames) + return nil } diff --git a/smartcontract/service/native/auth/utils.go b/smartcontract/service/native/auth/utils.go index 8d069232da..5e44087a5b 100644 --- a/smartcontract/service/native/auth/utils.go +++ b/smartcontract/service/native/auth/utils.go @@ -22,6 +22,7 @@ import ( "bytes" "fmt" "io" + "sort" "github.com/ontio/ontology/common" "github.com/ontio/ontology/common/serialization" @@ -181,8 +182,8 @@ func putDelegateStatus(native *native.NativeService, contractAddr common.Address return nil } -//remote duplicates in the slice of string -func stringSliceUniq(s []string) []string { +//remove duplicates in the slice of string and sorts the slice in increasing order. +func StringsDedupAndSort(s []string) []string { smap := make(map[string]int) for i, str := range s { if str == "" { @@ -192,10 +193,11 @@ func stringSliceUniq(s []string) []string { } ret := make([]string, len(smap)) i := 0 - for str, _ := range smap { + for str := range smap { ret[i] = str i++ } + sort.Strings(ret) return ret } diff --git a/smartcontract/service/native/auth/utils_test.go b/smartcontract/service/native/auth/utils_test.go index 45d20143f7..6478296226 100644 --- a/smartcontract/service/native/auth/utils_test.go +++ b/smartcontract/service/native/auth/utils_test.go @@ -17,37 +17,13 @@ */ package auth -import "testing" +import ( + "github.com/magiconair/properties/assert" + "testing" +) -//{"a", "b"} == {"b", "a"} -func testEq(a, b []string) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - return false - } - - if len(a) != len(b) { - return false - } - Map := make(map[string]bool) - for i := range a { - Map[a[i]] = true - } - for _, s := range b { - _, ok := Map[s] - if !ok { - return false - } - } - return true -} func TestStringSliceUniq(t *testing.T) { - s := []string{"foo", "foo1", "foo2", "foo", "foo1", "foo2", "foo3"} - ret := stringSliceUniq(s) - t.Log(ret) - if !testEq(ret, []string{"foo", "foo1", "foo2", "foo3"}) { - t.Fatalf("failed") - } + s := []string{"foo3", "foo", "foo1", "foo2", "foo", "foo1", "foo2", "foo3"} + ret := StringsDedupAndSort(s) + assert.Equal(t, ret, []string{"foo", "foo1", "foo2", "foo3"}) }