Skip to content

Commit

Permalink
增加Pair类型 (#237)
Browse files Browse the repository at this point in the history
* 增加对Pair类型的支持

* 按照comments修改设计

* 按照review意见修改code

* 修复可能出现的race
  • Loading branch information
dxyinme authored Jan 8, 2024
1 parent 5a23504 commit 5056d18
Show file tree
Hide file tree
Showing 4 changed files with 457 additions and 0 deletions.
27 changes: 27 additions & 0 deletions mapx/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package mapx

import "fmt"

// Keys 返回 map 里面的所有的 key。
// 需要注意:这些 key 的顺序是随机。
func Keys[K comparable, V any](m map[K]V) []K {
Expand Down Expand Up @@ -45,3 +47,28 @@ func KeysValues[K comparable, V any](m map[K]V) ([]K, []V) {
}
return keys, values
}

// ToMap 将会返回一个map[K]V
// 请保证传入的 keys 与 values 长度相同,长度均为n
// 长度不相同或者 keys 或者 values 为nil则会抛出异常
// 返回的 m map[K]V 保证对于所有的 0 <= i < n
// m[keys[i]] = values[i]
//
// 注意:
// 如果传入的数组中存在 0 <= i < j < n使得 keys[i] == keys[j]
// 则在返回的 m 中 m[keys[i]] = values[j]
// 如果keys和values的长度为0,则会返回一个空map
func ToMap[K comparable, V any](keys []K, values []V) (m map[K]V, err error) {
if keys == nil || values == nil {
return nil, fmt.Errorf("keys与values均不可为nil")
}
n := len(keys)
if n != len(values) {
return nil, fmt.Errorf("keys与values的长度不同, len(keys)=%d, len(values)=%d", n, len(values))
}
m = make(map[K]V, n)
for i := 0; i < n; i++ {
m[keys[i]] = values[i]
}
return
}
61 changes: 61 additions & 0 deletions mapx/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package mapx

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -147,3 +148,63 @@ func TestKeysValues(t *testing.T) {
})
}
}

func TestToMap(t *testing.T) {
type caseType struct {
keys []int
values []string

result map[int]string
err error
}
for _, c := range []caseType{
{
keys: []int{1, 2, 3},
values: []string{"1", "2", "3"},
result: map[int]string{
1: "1",
2: "2",
3: "3",
},
err: nil,
},
{
keys: []int{1, 2, 3},
values: []string{"1", "2"},
result: nil,
err: fmt.Errorf("keys与values的长度不同, len(keys)=3, len(values)=2"),
},
{
keys: []int{1, 2, 3},
values: nil,
result: nil,
err: fmt.Errorf("keys与values均不可为nil"),
},
{
keys: nil,
values: []string{"1", "2"},
result: nil,
err: fmt.Errorf("keys与values均不可为nil"),
},
{
keys: nil,
values: nil,
result: nil,
err: fmt.Errorf("keys与values均不可为nil"),
},
{
keys: []int{1, 2, 3, 1, 1},
values: []string{"1", "2", "3", "10", "100"},
result: map[int]string{
1: "100",
2: "2",
3: "3",
},
err: nil,
},
} {
result, err := ToMap(c.keys, c.values)
assert.Equal(t, c.err, err)
assert.Equal(t, c.result, result)
}
}
132 changes: 132 additions & 0 deletions tuple/pair/pair.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2021 ecodeclub
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pair

import (
"fmt"
)

type Pair[K any, V any] struct {
Key K
Value V
}

func (pair *Pair[K, V]) String() string {
return fmt.Sprintf("<%#v, %#v>", pair.Key, pair.Value)
}

// Split 方法将Key, Value作为返回参数传出。
func (pair *Pair[K, V]) Split() (K, V) {
return pair.Key, pair.Value
}

func NewPair[K any, V any](
key K,
value V,
) Pair[K, V] {
return Pair[K, V]{
Key: key,
Value: value,
}
}

// NewPairs 需要传入两个长度相同并且均不为nil的数组 keys 和 values,
// 设keys长度为n,返回一个长度为n的pair数组。
// 保证:
//
// 返回的pair数组满足条件(设pair数组为p):
// 对于所有的 0 <= i < n
// p[i].Key == keys[i] 并且 p[i].Value == values[i]
//
// 如果传入的keys或者values为nil,会返回error
//
// 如果传入的keys长度与values长度不同,会返回error
func NewPairs[K any, V any](
keys []K,
values []V,
) ([]Pair[K, V], error) {
if keys == nil || values == nil {
return nil, fmt.Errorf("keys与values均不可为nil")
}
n := len(keys)
if n != len(values) {
return nil, fmt.Errorf("keys与values的长度不同, len(keys)=%d, len(values)=%d", n, len(values))
}
pairs := make([]Pair[K, V], n)
for i := 0; i < n; i++ {
pairs[i] = NewPair(keys[i], values[i])
}
return pairs, nil
}

// SplitPairs 需要传入一个[]Pair[K, V],数组可以为nil。
// 设pairs数组的长度为n,返回两个长度均为n的数组keys, values。
// 如果pairs数组是nil, 则返回的keys与values也均为nil。
func SplitPairs[K any, V any](pairs []Pair[K, V]) (keys []K, values []V) {
if pairs == nil {
return nil, nil
}
n := len(pairs)
keys = make([]K, n)
values = make([]V, n)
for i, pair := range pairs {
keys[i], values[i] = pair.Split()
}
return
}

// FlattenPairs 需要传入一个[]Pair[K, V],数组可以为nil
// 如果pairs数组为nil,则返回的flatPairs数组也为nil
//
// 设pairs数组长度为n,保证返回的flatPairs数组长度为2 * n且满足:
// 对于所有的 0 <= i < n
// flatPairs[i * 2] == pairs[i].Key
// flatPairs[i * 2 + 1] == pairs[i].Value
func FlattenPairs[K any, V any](pairs []Pair[K, V]) (flatPairs []any) {
if pairs == nil {
return nil
}
n := len(pairs)
flatPairs = make([]any, 0, n*2)
for _, pair := range pairs {
flatPairs = append(flatPairs, pair.Key, pair.Value)
}
return
}

// PackPairs 需要传入一个长度为2 * n的数组flatPairs,数组可以为nil。
//
// 函数将会返回一个长度为n的pairs数组,pairs满足
// 对于所有的 0 <= i < n
// pairs[i].Key == flatPairs[i * 2]
// pairs[i].Value == flatPairs[i * 2 + 1]
// 如果flatPairs为nil,则返回的pairs也为nil
//
// 入参flatPairs需要满足以下条件:
// 对于所有的 0 <= i < n
// flatPairs[i * 2] 的类型为 K
// flatPairs[i * 2 + 1] 的类型为 V
// 否则会panic
func PackPairs[K any, V any](flatPairs []any) (pairs []Pair[K, V]) {
if flatPairs == nil {
return nil
}
n := len(flatPairs) / 2
pairs = make([]Pair[K, V], n)
for i := 0; i < n; i++ {
pairs[i] = NewPair(flatPairs[i*2].(K), flatPairs[i*2+1].(V))
}
return
}
Loading

0 comments on commit 5056d18

Please sign in to comment.