[test]Add test for array data type
Signed-off-by: zhuwenxing <[email protected]>
zhuwenxing committed Dec 12, 2023
1 parent fca8663 commit 36dd348
Showing 6 changed files with 154 additions and 17 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/main.yaml
@@ -492,7 +492,7 @@ jobs:
shell: bash
working-directory: tests
run: |
-          pytest -s -v --tags ${{ matrix.case_tag }} -n 4
+          pytest -s -v --tags ${{ matrix.case_tag }} -n 2
- name: Get Milvus status
shell: bash
@@ -518,7 +518,7 @@ jobs:
if: ${{ ! success() }}
uses: actions/upload-artifact@v2
with:
-          name: api-test-logs-${{ matrix.deploy_tools }}-${{ matrix.milvus_mode }}
+          name: api-test-logs-${{ matrix.deploy_tools }}-${{ matrix.milvus_mode }}-${{ matrix.case_tag }}
path: |
./logs
./server.log
26 changes: 17 additions & 9 deletions tests/base/client_base.py
@@ -1,4 +1,6 @@
import sys
+import time
+
from pymilvus import DefaultConfig, DataType, db

sys.path.append("..")
@@ -340,12 +342,14 @@ def is_binary_by_schema(self, schema):

def compare_collections(self, src_name, dist_name, output_fields=None):
if output_fields is None:
-            output_fields = [ct.default_int64_field_name, ct.default_json_field_name]
+            output_fields = ["*"]
collection_src, _ = self.collection_wrap.init_collection(name=src_name)
collection_dist, _ = self.collection_wrap.init_collection(name=dist_name)
assert collection_src.num_entities == collection_dist.num_entities, \
f"collection_src {src_name} num_entities: {collection_src.num_entities} != " \
f"collection_dist {dist_name} num_entities: {collection_dist.num_entities}"
log.info(f"collection_src schema: {collection_src.schema}")
log.info(f"collection_dist schema: {collection_dist.schema}")
assert collection_src.schema == collection_dist.schema
# get partitions
partitions_src = collection_src.partitions
@@ -355,18 +359,22 @@ def compare_collections(self, src_name, dist_name, output_fields=None):

for coll in [collection_src, collection_dist]:
is_binary = self.is_binary_by_schema(coll.schema)
-            if is_binary:
-                coll.create_index(ct.default_binary_vec_field_name, ct.default_bin_flat_index,
-                                  index_name=cf.gen_unique_str())
-            else:
-                coll.create_index(ct.default_float_vec_field_name, ct.default_index, index_name=cf.gen_unique_str())
-            coll.load()
+            try:
+                if is_binary:
+                    coll.create_index(ct.default_binary_vec_field_name, ct.default_bin_flat_index,
+                                      index_name=cf.gen_unique_str())
+                else:
+                    coll.create_index(ct.default_float_vec_field_name, ct.default_index, index_name=cf.gen_unique_str())
+            except Exception as e:
+                log.error(f"collection {coll.name} create index failed with error: {e}")
+            coll.load(timeout=120)
+        time.sleep(5)
src_res = collection_src.query(expr=f'{ct.default_int64_field_name} >= 0',
output_fields=output_fields)
log.info(f"src res: {len(src_res)}")
log.info(f"src res: {len(src_res)}, src res: {src_res[-1]}")
dist_res = collection_dist.query(expr=f'{ct.default_int64_field_name} >= 0',
output_fields=output_fields)
log.info(f"dist res: {len(dist_res)}")
log.info(f"dist res: {len(dist_res)}, dist res: {dist_res[-1]}")
assert len(dist_res) == len(src_res)

def check_collection_binary(self, name):
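A note on the compare_collections changes above: in pymilvus 2.3, passing output_fields=["*"] asks query() to return every field, including dynamic-field keys, so the comparison covers new columns such as arrays without listing them one by one. A minimal standalone sketch, assuming a local Milvus endpoint and an illustrative collection name:

from pymilvus import Collection, connections

connections.connect(host="localhost", port="19530")  # assumed local Milvus endpoint
coll = Collection("backup_src_demo")                 # hypothetical collection name
coll.load()
# "*" expands to all schema fields plus any dynamic fields
rows = coll.query(expr="int64 >= 0", output_fields=["*"])
print(len(rows), rows[-1])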
5 changes: 5 additions & 0 deletions tests/common/common_func.py
@@ -106,6 +106,11 @@ def gen_json_field(name=ct.default_json_field_name, is_primary=False, descriptio
description=description, is_primary=is_primary)
return json_field

+def gen_array_field(name=ct.default_array_field_name, is_primary=False, element_type=DataType.VARCHAR, description=ct.default_desc):
+    array_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.ARRAY,
+                                                               description=description, is_primary=is_primary, element_type=element_type, max_capacity=2000, max_length=1500)
+    return array_field
+

def gen_float_vec_field(name=ct.default_float_vec_field_name, is_primary=False, dim=ct.default_dim,
description=ct.default_desc):
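For reference, gen_array_field above wraps pymilvus 2.3's ARRAY data type, which requires an element_type and a max_capacity, plus a max_length when the elements are VARCHAR. A minimal sketch built directly on FieldSchema (field names here are illustrative):

from pymilvus import CollectionSchema, FieldSchema, DataType

# VARCHAR array: at most 2000 elements, each element at most 1500 characters
var_array = FieldSchema(name="var_array", dtype=DataType.ARRAY,
                        element_type=DataType.VARCHAR, max_capacity=2000, max_length=1500)
# INT64 array: only the capacity bound is needed
int_array = FieldSchema(name="int_array", dtype=DataType.ARRAY,
                        element_type=DataType.INT64, max_capacity=2000)
pk = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True)
vec = FieldSchema(name="float_vector", dtype=DataType.FLOAT_VECTOR, dim=128)
schema = CollectionSchema(fields=[pk, var_array, int_array, vec])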
5 changes: 3 additions & 2 deletions tests/common/common_type.py
@@ -34,6 +34,7 @@
default_double_field_name = "double"
default_string_field_name = "varchar"
default_json_field_name = "json"
+default_array_field_name = "array"
default_float_vec_field_name = "float_vector"
another_float_vec_field_name = "float_vector1"
default_binary_vec_field_name = "binary_vector"
@@ -73,8 +74,8 @@
err_msg = "err_msg"
in_cluster_env = "IN_CLUSTER"

-default_flat_index = {"index_type": "FLAT", "params": {}, "metric_type": "L2"}
-default_bin_flat_index = {"index_type": "BIN_FLAT", "params": {}, "metric_type": "JACCARD"}
+default_flat_index = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}}
+default_bin_flat_index = {"index_type": "BIN_IVF_FLAT", "params": {"nlist": 128}, "metric_type": "JACCARD"}

"""" List of parameters used to pass """
get_invalid_strs = [
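These default dicts are handed straight to create_index, so the change replaces the brute-force FLAT/BIN_FLAT defaults with quantized IVF variants (IVF_SQ8 with COSINE for float vectors, BIN_IVF_FLAT with JACCARD for binary vectors). A minimal sketch of how such params are consumed, assuming an existing collection with an illustrative name:

from pymilvus import Collection

coll = Collection("restored_demo")  # hypothetical collection name
index_params = {"index_type": "IVF_SQ8", "metric_type": "COSINE", "params": {"nlist": 64}}
coll.create_index(field_name="float_vector", index_params=index_params, index_name="float_idx")
coll.load(timeout=120)  # load succeeds once the vector field is indexed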
2 changes: 1 addition & 1 deletion tests/requirements.txt
@@ -17,7 +17,7 @@ pytest-print==0.2.1
pytest-level==0.1.1
pytest-xdist==2.5.0
pytest-loguru==0.2.0
-pymilvus==2.2.9.dev18
+pymilvus==2.3.2
pytest-rerunfailures==9.1.1
git+https://github.com/Projectplace/pytest-tags
ndg-httpsclient
129 changes: 126 additions & 3 deletions tests/testcases/test_restore_backup.py
@@ -3,7 +3,7 @@
import json
import numpy as np
from collections import defaultdict
-from pymilvus import db, list_collections, Collection
+from pymilvus import db, list_collections, Collection, DataType
from base.client_base import TestcaseBase
from common import common_func as cf
from common import common_type as ct
@@ -48,7 +48,6 @@ def test_milvus_restore_back(self, collection_type, collection_need_to_restore,
res, _ = self.utility_wrap.has_collection(name)
assert res is True
# create backup
-
names_need_backup = names_origin
payload = {"async": False, "backup_name": back_up_name, "collection_names": names_need_backup}
res = client.create_backup(payload)
@@ -137,7 +136,7 @@ def test_milvus_restore_back_with_multi_partition(self, collection_type, collect
assert name + suffix in res
for name in restore_collections:
self.compare_collections(name, name+suffix)
-
+
@pytest.mark.tags(CaseLabel.L1)
def test_milvus_restore_back_with_db_support(self):
# prepare data
@@ -322,3 +321,127 @@ def test_milvus_restore_with_db_collections(self, drop_db, str_json):
assert collection_name + suffix in res
if not drop_db:
self.compare_collections(collection_name, collection_name + suffix)
+
+    @pytest.mark.parametrize("include_partition_key", [True, False])
+    @pytest.mark.parametrize("include_dynamic", [True, False])
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_milvus_restore_back_with_array_datatype(self, include_dynamic, include_partition_key):
+        self._connect()
+        name_origin = cf.gen_unique_str(prefix)
+        back_up_name = cf.gen_unique_str(backup_prefix)
+        fields = [cf.gen_int64_field(name="int64", is_primary=True),
+                  cf.gen_int64_field(name="key"),
+                  cf.gen_json_field(name="json"),
+                  cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
+                  cf.gen_array_field(name="int_array", element_type=DataType.INT64),
+                  cf.gen_float_vec_field(name="float_vector", dim=128),
+                  ]
+        if include_partition_key:
+            partition_key = "key"
+            default_schema = cf.gen_collection_schema(fields,
+                                                      enable_dynamic_field=include_dynamic,
+                                                      partition_key_field=partition_key)
+        else:
+            default_schema = cf.gen_collection_schema(fields,
+                                                      enable_dynamic_field=include_dynamic)
+
+        collection_w = self.init_collection_wrap(name=name_origin, schema=default_schema, active_trace=True)
+        nb = 3000
+        data = [
+            [i for i in range(nb)],
+            [i % 3 for i in range(nb)],
+            [{f"key_{str(i)}": i} for i in range(nb)],
+            [[str(x) for x in range(10)] for i in range(nb)],
+            [[int(x) for x in range(10)] for i in range(nb)],
+            [[np.float32(i) for i in range(128)] for _ in range(nb)],
+        ]
+        collection_w.insert(data=data)
+        if include_dynamic:
+            data = [
+                {
+                    "int64": i,
+                    "key": i % 3,
+                    "json": {f"key_{str(i)}": i},
+                    "var_array": [str(x) for x in range(10)],
+                    "int_array": [int(x) for x in range(10)],
+                    "float_vector": [np.float32(i) for i in range(128)],
+                    f"dynamic_{str(i)}": i
+                } for i in range(nb, nb*2)
+            ]
+            collection_w.insert(data=data)
+        res = client.create_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin]})
+        log.info(f"create_backup {res}")
+        res = client.list_backup()
+        log.info(f"list_backup {res}")
+        if "data" in res:
+            all_backup = [r["name"] for r in res["data"]]
+        else:
+            all_backup = []
+        assert back_up_name in all_backup
+        backup = client.get_backup(back_up_name)
+        assert backup["data"]["name"] == back_up_name
+        backup_collections = [backup["collection_name"] for backup in backup["data"]["collection_backups"]]
+        assert name_origin in backup_collections
+        res = client.restore_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin],
+                                     "collection_suffix": suffix})
+        log.info(f"restore_backup: {res}")
+        res, _ = self.utility_wrap.list_collections()
+        assert name_origin + suffix in res
+        output_fields = None
+        self.compare_collections(name_origin, name_origin + suffix, output_fields=output_fields)
+        res = client.delete_backup(back_up_name)
+        res = client.list_backup()
+        if "data" in res:
+            all_backup = [r["name"] for r in res["data"]]
+        else:
+            all_backup = []
+        assert back_up_name not in all_backup
+
+    @pytest.mark.tags(CaseLabel.L1)
+    def test_milvus_restore_back_with_delete(self):
+        self._connect()
+        name_origin = cf.gen_unique_str(prefix)
+        back_up_name = cf.gen_unique_str(backup_prefix)
+        fields = [cf.gen_int64_field(name="int64", is_primary=True),
+                  cf.gen_int64_field(name="key"),
+                  cf.gen_json_field(name="json"),
+                  cf.gen_array_field(name="var_array", element_type=DataType.VARCHAR),
+                  cf.gen_array_field(name="int_array", element_type=DataType.INT64),
+                  cf.gen_float_vec_field(name="float_vector", dim=128),
+                  ]
+        default_schema = cf.gen_collection_schema(fields)
+        collection_w = self.init_collection_wrap(name=name_origin, schema=default_schema, active_trace=True)
+        nb = 3000
+        data = [
+            [i for i in range(nb)],
+            [i % 3 for i in range(nb)],
+            [{f"key_{str(i)}": i} for i in range(nb)],
+            [[str(x) for x in range(10)] for i in range(nb)],
+            [[int(x) for x in range(10)] for i in range(nb)],
+            [[np.float32(i) for i in range(128)] for _ in range(nb)],
+        ]
+        res, result = collection_w.insert(data=data)
+        pk = res.primary_keys
+        # delete the first 100 rows
+        delete_ids = pk[:100]
+        collection_w.delete(expr=f"int64 in {delete_ids}")
+        res = client.create_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin]})
+        log.info(f"create_backup {res}")
+        res = client.list_backup()
+        log.info(f"list_backup {res}")
+        if "data" in res:
+            all_backup = [r["name"] for r in res["data"]]
+        else:
+            all_backup = []
+        assert back_up_name in all_backup
+        backup = client.get_backup(back_up_name)
+        assert backup["data"]["name"] == back_up_name
+        backup_collections = [backup["collection_name"] for backup in backup["data"]["collection_backups"]]
+        assert name_origin in backup_collections
+        res = client.restore_backup({"async": False, "backup_name": back_up_name, "collection_names": [name_origin],
+                                     "collection_suffix": suffix})
+        log.info(f"restore_backup: {res}")
+        res, _ = self.utility_wrap.list_collections()
+        assert name_origin + suffix in res
+        output_fields = None
+        self.compare_collections(name_origin, name_origin + suffix, output_fields=output_fields)

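The delete test relies on Milvus expression-based deletes: Collection.delete() takes a boolean filter over the primary key, and ids removed before the backup should stay absent after restore. A minimal sketch of the delete-then-verify pattern, with an illustrative collection name and ids:

from pymilvus import Collection

coll = Collection("delete_demo")        # hypothetical collection name
coll.delete(expr="int64 in [0, 1, 2]")  # remove rows whose primary key is listed
coll.load()
res = coll.query(expr="int64 in [0, 1, 2]", output_fields=["int64"])
assert len(res) == 0                    # deleted ids no longer match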