Commit: changes

emekaokoli19 committed May 6, 2024
1 parent 5e2a592 commit fa5f40b
Showing 2 changed files with 102 additions and 84 deletions.
67 changes: 42 additions & 25 deletions src/vdf_io/export_vdf/weaviate_export.py
@@ -1,13 +1,14 @@
import os

from tqdm import tqdm
import weaviate
import json

from tqdm import tqdm
from weaviate.classes.query import MetadataQuery
from vdf_io.export_vdf.vdb_export_cls import ExportVDB
from vdf_io.meta_types import NamespaceMeta
from vdf_io.names import DBNames
from vdf_io.util import set_arg_from_input, set_arg_from_password
from vdf_io.constants import DEFAULT_BATCH_SIZE
from typing import Dict, List

# Set these environment variables
@@ -28,9 +29,13 @@ def make_parser(cls, subparsers):
parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
parser_weaviate.add_argument("--openai_api_key", type=str, help="Openai API key")
parser_weaviate.add_arguments(
parser_weaviate.add_argument(
"--batch_size", type=int, help="batch size for fetching",
default=1000
default=DEFAULT_BATCH_SIZE
)
parser_weaviate.add_argument(
"--offset", type=int, help="offset for fetching",
default=None
)
parser_weaviate.add_argument(
"--connection-type", type=str, choices=["local", "cloud"], default="cloud",
@@ -100,39 +105,50 @@ def get_index_names(self):
)
return [c for c in self.all_classes if c in input_classes]

def metadata_to_dict(self, metadata):
meta_data = {}
meta_data["creation_time"] = metadata.creation_time
meta_data["distance"] = metadata.distance
meta_data["certainty"] = metadata.certainty
meta_data["explain_score"] = metadata.explain_score
meta_data["is_consistent"] = metadata.is_consistent
meta_data["last_update_time"] = metadata.last_update_time
meta_data["rerank_score"] = metadata.rerank_score
meta_data["score"] = metadata.score

return meta_data

def get_data(self):
# Get the index names to export
index_names = self.get_index_names()
index_metas: Dict[str, List[NamespaceMeta]] = {}

# Export data in batches
batch_size = self.args["batch_size"]
offset = self.args["offset"]

# Iterate over index names and fetch data
for index_name in index_names:
collection = self.client.collections.get(index_name)
response = collection.aggregate.over_all(total_count=True)
total_vector_count = response.total_count
response = collection.query.fetch_objects(
limit=batch_size,
offset=offset,
include_vector=True,
return_metadata=MetadataQuery.full()
)
res = collection.aggregate.over_all(total_count=True)
total_vector_count = res.total_count

# Create vectors directory for this index
vectors_directory = self.create_vec_dir(index_name)

# Export data in batches
batch_size = self.args["batch_size"]
num_batches = (total_vector_count + batch_size - 1) // batch_size
num_vectors_exported = 0

for batch_idx in tqdm(range(num_batches), desc=f"Exporting {index_name}"):
offset = batch_idx * batch_size
objects = collection.objects.limit(batch_size).offset(offset).get()

# Extract vectors and metadata
vectors = {obj.id: obj.vector for obj in objects}
metadata = {}
# Need a better way
for obj in objects:
metadata[obj.id] = {attr: getattr(obj, attr) for attr in dir(obj) if not attr.startswith("__")}

for obj in response.objects:
vectors = obj.vector
metadata = obj.metadata
metadata = self.metadata_to_dict(metadata=metadata)

# Save vectors and metadata to Parquet file
num_vectors_exported += self.save_vectors_to_parquet(
num_vectors_exported = self.save_vectors_to_parquet(
vectors, metadata, vectors_directory
)

@@ -143,7 +159,7 @@ def get_data(self):
vectors_directory,
total=total_vector_count,
num_vectors_exported=num_vectors_exported,
dim=300, # Not sure of the dimensions
dim=-1,
distance="Cosine",
)
]
@@ -154,7 +170,8 @@ def get_data(self):
internal_metadata = self.get_basic_vdf_meta(index_metas)
meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
tqdm.write(meta_text)

with open(os.path.join(self.vdf_directory, "VDF_META.json"), "w") as json_file:
json_file.write(meta_text)
print("Data export complete.")

return True
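
For context, here is a minimal, self-contained sketch of the fetch pattern the reworked export path relies on: the v4 Weaviate Python client, a single fetch_objects call with include_vector=True, and full metadata via MetadataQuery.full(). The local connection, the collection name "MyCollection", and the page size of 100 are illustrative assumptions, not part of this commit.

import weaviate
from weaviate.classes.query import MetadataQuery

# Assumes a locally running Weaviate instance; swap in your own connection helper.
client = weaviate.connect_to_local()
try:
    collection = client.collections.get("MyCollection")

    # Total object count, used by the exporter to report progress.
    total = collection.aggregate.over_all(total_count=True).total_count
    print(f"{total} objects in collection")

    # Fetch one window of objects together with their vectors and full metadata.
    response = collection.query.fetch_objects(
        limit=100,
        offset=0,
        include_vector=True,
        return_metadata=MetadataQuery.full(),
    )
    for obj in response.objects:
        # obj.vector is keyed by vector name in recent client versions (e.g. "default");
        # obj.metadata carries creation_time, last_update_time, scores, etc.
        print(obj.uuid, obj.vector, obj.metadata.creation_time)
finally:
    client.close()
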
119 changes: 60 additions & 59 deletions src/vdf_io/import_vdf/weaviate_import.py
@@ -1,10 +1,10 @@
import os
import weaviate
import json
from tqdm import tqdm
from vdf_io.import_vdf.vdf_import_cls import ImportVDB
from vdf_io.names import DBNames
from vdf_io.util import set_arg_from_input, set_arg_from_password
from vdf_io.constants import INT_MAX, DEFAULT_BATCH_SIZE

# Set these environment variables
URL = os.getenv("YOUR_WCS_URL")
@@ -25,6 +25,14 @@ def make_parser(cls, subparsers):
parser_weaviate.add_argument(
"--index_name", type=str, help="Name of the index in Weaviate"
)
parser_weaviate.add_argument(
"--connection-type", type=str, choices=["local", "cloud"], default="cloud",
help="Type of connection to Weaviate (local or cloud)"
)
parser_weaviate.add_argument(
"--batch_size", type=int, help="batch size for fetching",
default=DEFAULT_BATCH_SIZE
)

@classmethod
def import_vdb(cls, args):
@@ -34,18 +42,24 @@ def import_vdb(cls, args):
"Enter the URL of Weaviate instance: ",
str,
)
set_arg_from_password(
args,
"api_key",
"Enter the Weaviate API key: ",
"WEAVIATE_API_KEY",
)
set_arg_from_input(
args,
"index_name",
"Enter the name of the index in Weaviate: ",
str,
)
set_arg_from_input(
args,
"connection_type",
"Enter 'local' or 'cloud' for connection types: ",
choices=['local', 'cloud'],
)
set_arg_from_password(
args,
"api_key",
"Enter the Weaviate API key: ",
"WEAVIATE_API_KEY",
)
weaviate_import = ImportWeaviate(args)
weaviate_import.upsert_data()
return weaviate_import
@@ -76,7 +90,6 @@ def upsert_data(self):

# Create or get the index
index_name = self.create_new_name(index_name, self.client.collections.list_all().keys())
index = self.client.collections.get(index_name)

# Load data from the Parquet files
data_path = namespace_meta["data_path"]
@@ -85,55 +98,43 @@ def upsert_data(self):

vectors = {}
metadata = {}
vector_column_names, vector_column_name = self.get_vector_column_name(
index_name, namespace_meta
)

# for file in tqdm(parquet_files, desc="Loading data from parquet files"):
# file_path = os.path.join(final_data_path, file)
# df = self.read_parquet_progress(file_path)

# if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
# max_hit = True
# break

# self.update_vectors(vectors, vector_column_name, df)
# self.update_metadata(metadata, vector_column_names, df)
# if max_hit:
# break

# tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")

# # Upsert the vectors and metadata to the Weaviate index in batches
# BATCH_SIZE = self.args.get("batch_size", 1000) or 1000
# current_batch_size = BATCH_SIZE
# start_idx = 0

# while start_idx < len(vectors):
# end_idx = min(start_idx + current_batch_size, len(vectors))

# batch_vectors = [
# (
# str(id),
# vector,
# {
# k: v
# for k, v in metadata.get(id, {}).items()
# if v is not None
# } if len(metadata.get(id, {}).keys()) > 0 else None
# )
# for id, vector in list(vectors.items())[start_idx:end_idx]
# ]

# try:
# resp = index.batch.create(batch_vectors)
# total_imported_count += len(batch_vectors)
# start_idx += len(batch_vectors)
# except Exception as e:
# tqdm.write(f"Error upserting vectors for index '{index_name}', {e}")
# if current_batch_size < BATCH_SIZE / 100:
# tqdm.write("Batch size is not the issue. Aborting import")
# raise e
# current_batch_size = int(2 * current_batch_size / 3)
# tqdm.write(f"Reducing batch size to {current_batch_size}")
# continue

# tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
# self.args["imported_count"] = total_imported_count
for file in tqdm(parquet_files, desc="Loading data from parquet files"):
file_path = os.path.join(final_data_path, file)
df = self.read_parquet_progress(file_path)

if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
max_hit = True
break
if len(vectors) + len(df) > (
self.args.get("max_num_rows") or INT_MAX
):
df = df.head(
(self.args.get("max_num_rows") or INT_MAX) - len(vectors)
)
max_hit = True
self.update_vectors(vectors, vector_column_name, df)
self.update_metadata(metadata, vector_column_names, df)
if max_hit:
break

tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")

# Upsert the vectors and metadata to the Weaviate index in batches
BATCH_SIZE = self.args.get("batch_size")

with self.client.batch.fixed_size(batch_size=BATCH_SIZE) as batch:
for _, vector in vectors.items():
batch.add_object(
vector=vector,
collection=index_name
# TODO: Find a way to add metadata
)
total_imported_count += 1


tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
self.args["imported_count"] = total_imported_count
