From 5056d18dff90389dc3b0fefb172a60526d217267 Mon Sep 17 00:00:00 2001 From: Dxyinme <32793868+dxyinme@users.noreply.github.com> Date: Mon, 8 Jan 2024 22:08:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0Pair=E7=B1=BB=E5=9E=8B=20(#23?= =?UTF-8?q?7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 增加对Pair类型的支持 * 按照comments修改设计 * 按照review意见修改code * 修复可能出现的race --- mapx/map.go | 27 +++++ mapx/map_test.go | 61 +++++++++++ tuple/pair/pair.go | 132 ++++++++++++++++++++++ tuple/pair/pair_test.go | 237 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 457 insertions(+) create mode 100644 tuple/pair/pair.go create mode 100644 tuple/pair/pair_test.go diff --git a/mapx/map.go b/mapx/map.go index 0d3f8678..78e838ff 100644 --- a/mapx/map.go +++ b/mapx/map.go @@ -14,6 +14,8 @@ package mapx +import "fmt" + // Keys 返回 map 里面的所有的 key。 // 需要注意:这些 key 的顺序是随机。 func Keys[K comparable, V any](m map[K]V) []K { @@ -45,3 +47,28 @@ func KeysValues[K comparable, V any](m map[K]V) ([]K, []V) { } return keys, values } + +// ToMap 将会返回一个map[K]V +// 请保证传入的 keys 与 values 长度相同,长度均为n +// 长度不相同或者 keys 或者 values 为nil则会抛出异常 +// 返回的 m map[K]V 保证对于所有的 0 <= i < n +// m[keys[i]] = values[i] +// +// 注意: +// 如果传入的数组中存在 0 <= i < j < n使得 keys[i] == keys[j] +// 则在返回的 m 中 m[keys[i]] = values[j] +// 如果keys和values的长度为0,则会返回一个空map +func ToMap[K comparable, V any](keys []K, values []V) (m map[K]V, err error) { + if keys == nil || values == nil { + return nil, fmt.Errorf("keys与values均不可为nil") + } + n := len(keys) + if n != len(values) { + return nil, fmt.Errorf("keys与values的长度不同, len(keys)=%d, len(values)=%d", n, len(values)) + } + m = make(map[K]V, n) + for i := 0; i < n; i++ { + m[keys[i]] = values[i] + } + return +} diff --git a/mapx/map_test.go b/mapx/map_test.go index 569ac2a2..0d847013 100644 --- a/mapx/map_test.go +++ b/mapx/map_test.go @@ -15,6 +15,7 @@ package mapx import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -147,3 +148,63 @@ func TestKeysValues(t *testing.T) { }) } } + +func TestToMap(t *testing.T) { + type caseType struct { + keys []int + values []string + + result map[int]string + err error + } + for _, c := range []caseType{ + { + keys: []int{1, 2, 3}, + values: []string{"1", "2", "3"}, + result: map[int]string{ + 1: "1", + 2: "2", + 3: "3", + }, + err: nil, + }, + { + keys: []int{1, 2, 3}, + values: []string{"1", "2"}, + result: nil, + err: fmt.Errorf("keys与values的长度不同, len(keys)=3, len(values)=2"), + }, + { + keys: []int{1, 2, 3}, + values: nil, + result: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: nil, + values: []string{"1", "2"}, + result: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: nil, + values: nil, + result: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: []int{1, 2, 3, 1, 1}, + values: []string{"1", "2", "3", "10", "100"}, + result: map[int]string{ + 1: "100", + 2: "2", + 3: "3", + }, + err: nil, + }, + } { + result, err := ToMap(c.keys, c.values) + assert.Equal(t, c.err, err) + assert.Equal(t, c.result, result) + } +} diff --git a/tuple/pair/pair.go b/tuple/pair/pair.go new file mode 100644 index 00000000..b692d54f --- /dev/null +++ b/tuple/pair/pair.go @@ -0,0 +1,132 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pair + +import ( + "fmt" +) + +type Pair[K any, V any] struct { + Key K + Value V +} + +func (pair *Pair[K, V]) String() string { + return fmt.Sprintf("<%#v, %#v>", pair.Key, pair.Value) +} + +// Split 方法将Key, Value作为返回参数传出。 +func (pair *Pair[K, V]) Split() (K, V) { + return pair.Key, pair.Value +} + +func NewPair[K any, V any]( + key K, + value V, +) Pair[K, V] { + return Pair[K, V]{ + Key: key, + Value: value, + } +} + +// NewPairs 需要传入两个长度相同并且均不为nil的数组 keys 和 values, +// 设keys长度为n,返回一个长度为n的pair数组。 +// 保证: +// +// 返回的pair数组满足条件(设pair数组为p): +// 对于所有的 0 <= i < n +// p[i].Key == keys[i] 并且 p[i].Value == values[i] +// +// 如果传入的keys或者values为nil,会返回error +// +// 如果传入的keys长度与values长度不同,会返回error +func NewPairs[K any, V any]( + keys []K, + values []V, +) ([]Pair[K, V], error) { + if keys == nil || values == nil { + return nil, fmt.Errorf("keys与values均不可为nil") + } + n := len(keys) + if n != len(values) { + return nil, fmt.Errorf("keys与values的长度不同, len(keys)=%d, len(values)=%d", n, len(values)) + } + pairs := make([]Pair[K, V], n) + for i := 0; i < n; i++ { + pairs[i] = NewPair(keys[i], values[i]) + } + return pairs, nil +} + +// SplitPairs 需要传入一个[]Pair[K, V],数组可以为nil。 +// 设pairs数组的长度为n,返回两个长度均为n的数组keys, values。 +// 如果pairs数组是nil, 则返回的keys与values也均为nil。 +func SplitPairs[K any, V any](pairs []Pair[K, V]) (keys []K, values []V) { + if pairs == nil { + return nil, nil + } + n := len(pairs) + keys = make([]K, n) + values = make([]V, n) + for i, pair := range pairs { + keys[i], values[i] = pair.Split() + } + return +} + +// FlattenPairs 需要传入一个[]Pair[K, V],数组可以为nil +// 如果pairs数组为nil,则返回的flatPairs数组也为nil +// +// 设pairs数组长度为n,保证返回的flatPairs数组长度为2 * n且满足: +// 对于所有的 0 <= i < n +// flatPairs[i * 2] == pairs[i].Key +// flatPairs[i * 2 + 1] == pairs[i].Value +func FlattenPairs[K any, V any](pairs []Pair[K, V]) (flatPairs []any) { + if pairs == nil { + return nil + } + n := len(pairs) + flatPairs = make([]any, 0, n*2) + for _, pair := range pairs { + flatPairs = append(flatPairs, pair.Key, pair.Value) + } + return +} + +// PackPairs 需要传入一个长度为2 * n的数组flatPairs,数组可以为nil。 +// +// 函数将会返回一个长度为n的pairs数组,pairs满足 +// 对于所有的 0 <= i < n +// pairs[i].Key == flatPairs[i * 2] +// pairs[i].Value == flatPairs[i * 2 + 1] +// 如果flatPairs为nil,则返回的pairs也为nil +// +// 入参flatPairs需要满足以下条件: +// 对于所有的 0 <= i < n +// flatPairs[i * 2] 的类型为 K +// flatPairs[i * 2 + 1] 的类型为 V +// 否则会panic +func PackPairs[K any, V any](flatPairs []any) (pairs []Pair[K, V]) { + if flatPairs == nil { + return nil + } + n := len(flatPairs) / 2 + pairs = make([]Pair[K, V], n) + for i := 0; i < n; i++ { + pairs[i] = NewPair(flatPairs[i*2].(K), flatPairs[i*2+1].(V)) + } + return +} diff --git a/tuple/pair/pair_test.go b/tuple/pair/pair_test.go new file mode 100644 index 00000000..233d08aa --- /dev/null +++ b/tuple/pair/pair_test.go @@ -0,0 +1,237 @@ +// Copyright 2021 ecodeclub +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pair_test + +import ( + "fmt" + "sort" + "testing" + + "github.com/ecodeclub/ekit/mapx" + "github.com/ecodeclub/ekit/tuple/pair" + "github.com/stretchr/testify/suite" +) + +type testPairSuite struct{ suite.Suite } + +func (s *testPairSuite) TestString() { + { + p := pair.NewPair(100, "23333") + s.Assert().Equal("<100, \"23333\">", p.String()) + } + { + p := pair.NewPair("testStruct", map[int]int{ + 11: 1, + 22: 2, + 33: 3, + }) + s.Assert().Equal("<\"testStruct\", map[int]int{11:1, 22:2, 33:3}>", p.String()) + } +} + +func (s *testPairSuite) TestNewPairs() { + type caseType struct { + // input + keys []int + values []string + // expected + pairs []pair.Pair[int, string] + err error + } + for _, c := range []caseType{ + { + keys: []int{1, 2, 3, 4, 5}, + values: []string{"1", "2", "3", "4", "5"}, + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + err: nil, + }, + { + keys: nil, + values: []string{"1"}, + pairs: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: []int{1}, + values: nil, + pairs: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: nil, + values: nil, + pairs: nil, + err: fmt.Errorf("keys与values均不可为nil"), + }, + { + keys: []int{1, 2}, + values: []string{"1"}, + pairs: nil, + err: fmt.Errorf("keys与values的长度不同, len(keys)=2, len(values)=1"), + }, + } { + pairs, err := pair.NewPairs(c.keys, c.values) + s.Assert().Equal(c.err, err) + s.Assert().EqualValues(c.pairs, pairs) + } +} + +func (s *testPairSuite) TestSplitPairs() { + type caseType struct { + // input + pairs []pair.Pair[int, string] + // expected + keys []int + values []string + } + for _, c := range []caseType{ + { + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + keys: []int{1, 2, 3, 4, 5}, + values: []string{"1", "2", "3", "4", "5"}, + }, + { + pairs: nil, + + keys: nil, + values: nil, + }, + { + pairs: []pair.Pair[int, string]{}, + keys: []int{}, + values: []string{}, + }, + } { + keys, values := pair.SplitPairs(c.pairs) + if c.pairs == nil { + s.Assert().Nil(keys) + s.Assert().Nil(values) + } else { + s.Assert().Len(keys, len(c.pairs)) + s.Assert().Len(values, len(c.pairs)) + for i, pair := range c.pairs { + s.Assert().Equal(pair.Key, keys[i]) + s.Assert().Equal(pair.Value, values[i]) + } + } + } +} + +func (s *testPairSuite) TestFlattenPairs() { + type caseType struct { + pairs []pair.Pair[int, string] + flattPairs []any + } + + for _, c := range []caseType{ + { + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + flattPairs: []any{1, "1", 2, "2", 3, "3", 4, "4", 5, "5"}, + }, + { + pairs: nil, + flattPairs: nil, + }, + { + pairs: []pair.Pair[int, string]{}, + flattPairs: []any{}, + }, + } { + flatPairs := pair.FlattenPairs(c.pairs) + s.Assert().EqualValues(c.flattPairs, flatPairs) + } +} + +func (s *testPairSuite) TestPackPairs() { + type caseType struct { + flattPairs []any + pairs []pair.Pair[int, string] + } + + for _, c := range []caseType{ + { + flattPairs: []any{1, "1", 2, "2", 3, "3", 4, "4", 5, "5"}, + pairs: []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + pair.NewPair(4, "4"), + pair.NewPair(5, "5"), + }, + }, + { + flattPairs: nil, + pairs: nil, + }, + { + flattPairs: []any{}, + pairs: []pair.Pair[int, string]{}, + }, + } { + pairs := pair.PackPairs[int, string](c.flattPairs) + s.Assert().EqualValues(c.pairs, pairs) + } +} + +func (s *testPairSuite) TestMapPairMapping() { + // map to pairs + expectedMap := map[int]string{ + 1: "1", + 2: "2", + 3: "3", + } + expectedPairs := []pair.Pair[int, string]{ + pair.NewPair(1, "1"), + pair.NewPair(2, "2"), + pair.NewPair(3, "3"), + } + + // 可以用这种方式实现map到[]Pair的映射 + pairs, err := pair.NewPairs(mapx.KeysValues(expectedMap)) + s.Assert().Nil(err) + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].Key < pairs[j].Key + }) + s.Assert().EqualValues(expectedPairs, pairs) + + // 可以用这种方式实现[]Pair到map的映射 + mp, err := mapx.ToMap(pair.SplitPairs(expectedPairs)) + s.Assert().Nil(err) + for k, v := range mp { + s.Assert().Equal(expectedMap[k], v) + } +} + +func TestPair(t *testing.T) { + suite.Run(t, new(testPairSuite)) +}