Skip to content

Commit

Permalink
enhance: support Int8Vector (#2611)
Browse files Browse the repository at this point in the history
Issue: milvus-io/milvus#38666

Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain authored Feb 7, 2025
1 parent 7ee1527 commit 768a2dd
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 2 deletions.
72 changes: 72 additions & 0 deletions examples/datatypes/int8_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import time
import random
import numpy as np
from pymilvus import (
connections,
utility,
FieldSchema, CollectionSchema, DataType,
Collection,
)
from pymilvus import MilvusClient

# Index types exercised by this example for the INT8_VECTOR field.
int8_index_types = ["HNSW"]

# Build parameters for each entry of int8_index_types, matched by position
# (the search loop looks them up with the same index).
default_int8_index_params = [{"M": 8, "efConstruction": 200}]


def gen_int8_vectors(num, dim):
    """Generate ``num`` random vectors of length ``dim`` in the int8 range.

    Returns:
        A ``(raw_vectors, int8_vectors)`` pair. ``raw_vectors`` is a list of
        plain Python ``int`` lists with values in ``[-128, 127]``;
        ``int8_vectors`` holds the same values as ``np.int8`` arrays, in the
        same order.
    """
    # Values are drawn row by row so the random sequence is consumed in the
    # same order as a nested loop would consume it.
    raw_vectors = [
        [random.randint(-128, 127) for _ in range(dim)]
        for _ in range(num)
    ]
    int8_vectors = [np.array(row, dtype=np.int8) for row in raw_vectors]
    return raw_vectors, int8_vectors


def int8_vector_search():
    """Run an end-to-end INT8_VECTOR demo against a Milvus server.

    Steps:
      1. Connect using the default connection settings.
      2. (Re)create the ``hello_milvus_int8`` collection with an auto-id INT64
         primary key and a 128-dim INT8_VECTOR field.
      3. Insert twelve vectors: six column-wise, six as row dicts.
      4. For every entry in ``int8_index_types``: build the index, load the
         collection, search with the first ten vectors, print the top hit's
         stored vector (raw and decoded), then release and drop the index.
      5. Drop the collection.
    """
    connections.connect()

    collection_name = "hello_milvus_int8"
    vector_field_name = "int8_vector"
    dim = 128
    nb = 3000

    int64_field = FieldSchema(name="int64", dtype=DataType.INT64, is_primary=True, auto_id=True)
    int8_vector = FieldSchema(name=vector_field_name, dtype=DataType.INT8_VECTOR, dim=dim)
    schema = CollectionSchema(fields=[int64_field, int8_vector])

    # Start from a clean slate so reruns do not fail on an existing collection.
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    hello_milvus = Collection(collection_name, schema)

    _, vectors = gen_int8_vectors(nb, dim)

    # Column-based insert: one list per field (the primary key is auto-generated).
    hello_milvus.insert([vectors[:6]])

    # Row-based insert: one dict per entity.
    rows = [{vector_field_name: vector} for vector in vectors[6:12]]
    hello_milvus.insert(rows)
    hello_milvus.flush()

    for index_type, index_params in zip(int8_index_types, default_int8_index_params):
        hello_milvus.create_index(
            vector_field_name,
            index_params={"index_type": index_type, "params": index_params, "metric_type": "L2"},
        )
        hello_milvus.load()
        print("index_type = ", index_type)
        res = hello_milvus.search(
            vectors[0:10],
            vector_field_name,
            {"metric_type": "L2"},
            limit=1,
            output_fields=[vector_field_name],
        )
        # Read back the int8 field that was requested in output_fields.
        # (Previously this looked up "float16_vector", which is not part of
        # this collection, so the raw value was never shown.)
        top_hit_vector = res[0][0].get(vector_field_name)
        print("raw bytes: ", top_hit_vector)
        print("numpy ndarray: ", np.frombuffer(top_hit_vector, dtype=np.int8))
        hello_milvus.release()
        hello_milvus.drop_index()

    hello_milvus.drop()


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    int8_vector_search()
10 changes: 9 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def get_fields_by_range(
field_meta.vectors.dim = dim
if dtype == DataType.FLOAT_VECTOR:
if start == 0 and (end - start) * dim >= len(vectors.float_vector.data):
# If the range equals to the lenth of ectors.float_vector.data, direct return
# If the range equals to the length of vectors.float_vector.data, direct return
# it to avoid a copy. This logic improves performance by 25% for the case
# of retrieving 1536-dim embeddings with topk=16384.
field2data[name] = vectors.float_vector.data, field_meta
Expand Down Expand Up @@ -648,6 +648,13 @@ def get_fields_by_range(
field_meta,
)
continue

if dtype == DataType.INT8_VECTOR:
field2data[name] = (
vectors.int8_vector[start * dim : end * dim],
field_meta,
)
continue
return field2data

def __iter__(self) -> SequenceIterator:
Expand Down Expand Up @@ -706,6 +713,7 @@ def __init__(
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.INT8_VECTOR,
):
dim = field_meta.vectors.dim
if field_meta.type in [DataType.BINARY_VECTOR]:
Expand Down
26 changes: 26 additions & 0 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,23 @@ def pack_field_value_to_field_data(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "sparse_float_vector", type(field_value))
) from e
elif field_type == DataType.INT8_VECTOR:
try:
i_value = field_value
if isinstance(field_value, np.ndarray):
if field_value.dtype != "int8":
raise ParamError(
message="invalid input for int8 vector. Expected an np.ndarray with dtype=int8"
)
v_bytes = field_value.view(np.int8).tobytes()

field_data.vectors.dim = len(i_value)
field_data.vectors.int8_vector += v_bytes
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "int8_vector", type(field_value))
) from e
elif field_type == DataType.VARCHAR:
try:
if field_value is None:
Expand Down Expand Up @@ -561,6 +578,15 @@ def entity_to_field_data(entity: Any, field_info: Any, num_rows: int):
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "sparse_float_vector", type(entity.get("values")[0]))
) from e
elif entity_type == DataType.INT8_VECTOR:
try:
field_data.vectors.dim = len(entity.get("values")[0])
field_data.vectors.int8_vector = b"".join(entity.get("values"))
except (TypeError, ValueError) as e:
raise DataNotMatchException(
message=ExceptionsMessage.FieldDataInconsistent
% (field_name, "int8_vector", type(entity.get("values")[0]))
) from e
else:
raise ParamError(message=f"Unsupported data type: {entity_type}")

Expand Down
1 change: 1 addition & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,7 @@ def create_index(
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.SPARSE_FLOAT_VECTOR,
DataType.INT8_VECTOR,
}:
break

Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,9 @@ def _prepare_placeholder_str(cls, data: Any):
elif dtype in ("float32", "float64"):
pl_type = PlaceholderType.FloatVector
pl_values = (blob.vector_float_to_bytes(entity) for entity in data)
elif dtype == "int8":
pl_type = PlaceholderType.Int8Vector
pl_values = (array.tobytes() for array in data)

elif dtype == "byte":
pl_type = PlaceholderType.BinaryVector
Expand Down
2 changes: 2 additions & 0 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class DataType(IntEnum):
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103
SPARSE_FLOAT_VECTOR = 104
INT8_VECTOR = 105

UNKNOWN = 999

Expand Down Expand Up @@ -179,6 +180,7 @@ class PlaceholderType(IntEnum):
FLOAT16_VECTOR = 102
BFLOAT16_VECTOR = 103
SparseFloatVector = 104
Int8Vector = 105
VARCHAR = 21


Expand Down
5 changes: 5 additions & 0 deletions pymilvus/orm/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def _parse_type_params(self):
DataType.VARCHAR,
DataType.ARRAY,
DataType.SPARSE_FLOAT_VECTOR,
DataType.INT8_VECTOR,
):
return
if not self._kwargs:
Expand Down Expand Up @@ -819,12 +820,15 @@ def prepare_fields_from_dataframe(df: pd.DataFrame):
DataType.FLOAT_VECTOR,
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.INT8_VECTOR,
):
vector_type_params = {}
if new_dtype == DataType.BINARY_VECTOR:
vector_type_params["dim"] = len(values[i]) * 8
elif new_dtype in (DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR):
vector_type_params["dim"] = int(len(values[i]) // 2)
elif new_dtype == DataType.INT8_VECTOR:
vector_type_params["dim"] = len(values[i])
else:
vector_type_params["dim"] = len(values[i])
column_params_map[col_names[i]] = vector_type_params
Expand All @@ -849,6 +853,7 @@ def check_schema(schema: CollectionSchema):
DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.SPARSE_FLOAT_VECTOR,
DataType.INT8_VECTOR,
):
vector_fields.append(field.name)
if len(vector_fields) < 1:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ def test_search_result_with_fields_data(self, pk):
bfloat16_vector=os.urandom(32),
),
),
schema_pb2.FieldData(type=DataType.INT8_VECTOR, field_name="int8_vector_field", field_id=117,
vectors=schema_pb2.VectorField(
dim=16,
int8_vector=os.urandom(32),
),
),
]
result = schema_pb2.SearchResultData(
fields_data=fields_data,
Expand All @@ -204,3 +210,4 @@ def test_search_result_with_fields_data(self, pk):
assert [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] == r[0][1].int64_array_field
assert 32 == len(r[0][0].entity.bfloat16_vector_field)
assert 32 == len(r[0][0].entity.float16_vector_field)
assert 16 == len(r[0][0].entity.int8_vector_field)
4 changes: 3 additions & 1 deletion tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def test_collection_by_DataFrame(self):
FieldSchema("int64", DataType.INT64),
FieldSchema("float", DataType.FLOAT),
FieldSchema("float_vector", DataType.FLOAT_VECTOR, dim=128),
FieldSchema("binary_vector", DataType.BINARY_VECTOR, dim=128),
FieldSchema("float16_vector", DataType.FLOAT16_VECTOR, dim=128),
FieldSchema("bfloat16_vector", DataType.BFLOAT16_VECTOR, dim=128)
FieldSchema("bfloat16_vector", DataType.BFLOAT16_VECTOR, dim=128),
FieldSchema("int8_vector", DataType.INT8_VECTOR, dim=128),
]

prefix = "pymilvus.client.grpc_handler.GrpcHandler"
Expand Down
17 changes: 17 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def raw_dict_bfloat16_vector(self):
_dict["params"] = {"dim": 128}
return _dict

@pytest.fixture(scope="function")
def raw_dict_int8_vector(self):
    """Provide a raw field dict describing a 128-dim INT8_VECTOR field."""
    return {
        "name": "TestFieldSchema_name_int8_vector",
        "description": "TestFieldSchema_description_int8_vector",
        "type": DataType.INT8_VECTOR,
        "params": {"dim": 128},
    }

@pytest.fixture(scope="function")
def raw_dict_norm(self):
_dict = dict()
Expand Down Expand Up @@ -143,6 +152,14 @@ def test_constructor_from_bfloat16_dict(self, raw_dict_bfloat16_vector):
assert field.name == raw_dict_bfloat16_vector['name']
assert field.dim == raw_dict_bfloat16_vector['params']['dim']

def test_constructor_from_int8_dict(self, raw_dict_int8_vector):
    """Check that a field built from an INT8_VECTOR dict mirrors that dict."""
    source = raw_dict_int8_vector
    field = FieldSchema.construct_from_dict(source)

    assert field.dtype == DataType.INT8_VECTOR
    assert field.description == source['description']
    # The source dict does not mark the field as primary.
    assert field.is_primary is False
    assert field.name == source['name']
    assert field.dim == source['params']['dim']

def test_constructor_from_norm_dict(self, raw_dict_norm):
field = FieldSchema.construct_from_dict(raw_dict_norm)
assert field.dtype == DataType.INT64
Expand Down
1 change: 1 addition & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class TestTypes:
([np.int8(1)], DataType.FLOAT_VECTOR),
([np.float16(1.0)], DataType.FLOAT16_VECTOR),
# ([np.array([1, 1], dtype=bfloat16)], DataType.BFLOAT16_VECTOR),
([np.int8(1)], DataType.INT8_VECTOR),
])
def test_infer_dtype_bydata(self, input_expect):
data, expect = input_expect
Expand Down

0 comments on commit 768a2dd

Please sign in to comment.