Skip to content

Commit

Permalink
Vector store, autodetect mode (#65)
Browse files Browse the repository at this point in the history
* content_field parameter all the way to encoders; setup_mode uses _NOT_SET

* vectorstore and decoders have 'ignore_invalid_documents' settings

* WIP with missing flat encoders

* encoders impl complete (wip)

* tests for flat encoders; logging in vectorstore; fixes in encoders

* refactor into autodetect utils module; add autodetect unit tests

* integration test of autodetect + compatibility with default schema

* rename encoder->codec throughout

* simplified counter in autodetect; warning made into logger warning; reverted docstring on setup_mode

* setup_mode simplified not to use NOT_SET

* content_field not using NOT_SET anymore; adjusted unit tests to new warnings/errors

* autodetect inferences unit tests made into parametrized tests

* More clarity in autodetect unit test assets

* finalize docstring

* bump to 0.4.0 ; docstring style

* vstore: update docs URLs + backticks fixes + docstring for from_documents
  • Loading branch information
hemidactylus authored Sep 9, 2024
1 parent cc40f28 commit 2e66453
Show file tree
Hide file tree
Showing 14 changed files with 2,412 additions and 905 deletions.
39 changes: 39 additions & 0 deletions libs/astradb/langchain_astradb/utils/astradb.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
# Thread/coroutine count for one-doc-at-a-time deletes:
MAX_CONCURRENT_DOCUMENT_DELETIONS = 20

# Maximum number of documents requested when sampling ("surveying") a collection:
SURVEY_NUMBER_OF_DOCUMENTS = 15

# NOTE(review): called without a name, so this is the root logger -- confirm
# that is intended rather than `logging.getLogger(__name__)`.
logger = logging.getLogger()


Expand All @@ -46,6 +49,42 @@ class SetupMode(Enum):
OFF = 3


def _survey_collection(
    collection_name: str,
    *,
    token: str | TokenProvider | None = None,
    api_endpoint: str | None = None,
    environment: str | None = None,
    astra_db_client: AstraDB | None = None,
    async_astra_db_client: AsyncAstraDB | None = None,
    namespace: str | None = None,
) -> tuple[CollectionDescriptor | None, list[dict[str, Any]]]:
    """Look up a collection by name and fetch a small sample of its documents.

    All keyword arguments are forwarded unchanged to ``_AstraDBEnvironment``,
    whose ``database`` attribute is then used for every call below.

    Args:
        collection_name: name of the collection to look for.
        token: forwarded to ``_AstraDBEnvironment``.
        api_endpoint: forwarded to ``_AstraDBEnvironment``.
        environment: forwarded to ``_AstraDBEnvironment``.
        astra_db_client: forwarded to ``_AstraDBEnvironment``.
        async_astra_db_client: forwarded to ``_AstraDBEnvironment``.
        namespace: forwarded to ``_AstraDBEnvironment``.

    Returns:
        A pair ``(descriptor, documents)``. If no listed collection has the
        requested name, the pair is ``(None, [])``. Otherwise ``descriptor`` is
        the first matching entry from ``list_collections()`` and ``documents``
        is the materialized result of a ``find`` with an empty filter, a
        ``{"*": True}`` projection and ``limit=SURVEY_NUMBER_OF_DOCUMENTS``.
    """
    astra_env = _AstraDBEnvironment(
        token=token,
        api_endpoint=api_endpoint,
        environment=environment,
        astra_db_client=astra_db_client,
        async_astra_db_client=async_astra_db_client,
        namespace=namespace,
    )

    matching_descriptors = []
    for candidate in astra_env.database.list_collections():
        if candidate.name == collection_name:
            matching_descriptors.append(candidate)

    if len(matching_descriptors) == 0:
        # The collection does not exist: nothing to sample.
        return None, []

    # The collection exists: pull a bounded sample of full documents.
    collection = astra_env.database.get_collection(collection_name)
    sample_cursor = collection.find(
        filter={},
        projection={"*": True},
        limit=SURVEY_NUMBER_OF_DOCUMENTS,
    )
    sampled_documents = list(sample_cursor)
    return (matching_descriptors[0], sampled_documents)


class _AstraDBEnvironment:
def __init__(
self,
Expand Down
200 changes: 0 additions & 200 deletions libs/astradb/langchain_astradb/utils/encoders.py

This file was deleted.

130 changes: 130 additions & 0 deletions libs/astradb/langchain_astradb/utils/vector_store_autodetect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Utilities for AstraDB vector store autodetect mode."""

from __future__ import annotations

import logging
from collections import Counter
from operator import itemgetter
from typing import (
Any,
)

from langchain_astradb.utils.vector_store_codecs import (
_AstraDBVectorStoreDocumentCodec,
_DefaultVectorizeVSDocumentCodec,
_DefaultVSDocumentCodec,
_FlatVectorizeVSDocumentCodec,
_FlatVSDocumentCodec,
)

# Module-level logger, named after this module; used for autodetect diagnostics.
logger = logging.getLogger(__name__)


def _detect_document_flatness(document: dict[str, Any]) -> bool | None:
"""Try to guess, when possible, if this document has metadata-as-a-dict or not."""
_metadata = document.get("metadata")
_vector = document.get("$vector")
_regularfields = set(document.keys()) - {"_id", "$vector"}
_regularfields_m = _regularfields - {"metadata"}
# cannot determine if ...
if _vector is None:
return None
# now a determination
if isinstance(_metadata, dict) and _regularfields_m:
return False
if isinstance(_metadata, dict) and not _regularfields_m:
# this document should not contribute to the survey
return None
str_regularfields = {
k for k, v in document.items() if isinstance(v, str) if k in _regularfields
}
if str_regularfields:
# Note: even if the only string field is "metadata"
return True
return None


def _detect_documents_flatness(documents: list[dict[str, Any]]) -> bool:
flatness_survey = [_detect_document_flatness(document) for document in documents]
n_flats = flatness_survey.count(True)
n_deeps = flatness_survey.count(False)
if n_flats > 0 and n_deeps > 0:
msg = "Mixed document shapes detected on collection during autodetect."
raise ValueError(msg)

# in absence of clues, 0 < 0 is False and default is NON FLAT (i.e. native)
return n_deeps < n_flats


def _detect_document_content_field(document: dict[str, Any]) -> str | None:
"""Try to guess the content field by inspecting the passed document."""
strlen_map = {
k: len(v) for k, v in document.items() if k != "_id" if isinstance(v, str)
}
if not strlen_map:
return None
return sorted(strlen_map.items(), key=itemgetter(1), reverse=True)[0][0]


def _detect_documents_content_field(
documents: list[dict[str, Any]],
requested_content_field: str,
) -> str:
if requested_content_field == "*":
# guess content_field by docs inspection
content_fields = [
_detect_document_content_field(document) for document in documents
]
valid_content_fields = [cf for cf in content_fields if cf is not None]
logger.info(
"vector store autodetect: inferring content_field from %i documents",
len(valid_content_fields),
)
cf_stats = Counter(valid_content_fields)
if not cf_stats:
msg = "Could not infer content_field name from sampled documents."
raise ValueError(msg)
return cf_stats.most_common(1)[0][0]

return requested_content_field


def _detect_document_codec(
    documents: list[dict[str, Any]],
    *,
    has_vectorize: bool,
    ignore_invalid_documents: bool,
    norm_content_field: str,
) -> _AstraDBVectorStoreDocumentCodec:
    """Pick the document codec matching the shape of the sampled documents.

    The documents are surveyed for flatness and for the content field, in that
    order, before a codec is built; both steps run (and may raise) regardless
    of ``has_vectorize``.

    Args:
        documents: the sampled documents to inspect.
        has_vectorize: whether to return one of the "vectorize" codecs. Those
            are constructed without a content field.
        ignore_invalid_documents: passed through to the codec constructor.
        norm_content_field: the requested content field, passed to
            ``_detect_documents_content_field`` (where ``"*"`` requests
            inference from the documents).

    Returns:
        One of the four codecs, chosen by ``has_vectorize`` and by the
        detected flatness.

    Raises:
        ValueError: propagated from the flatness or content-field detection.
    """
    logger.info("vector store autodetect: inspecting %i documents", len(documents))

    # Step 1: flat vs. non-flat layout.
    flat_layout = _detect_documents_flatness(documents)
    logger.info("vector store autodetect: is_flat = %s", flat_layout)

    # Step 2: content field (explicit, or inferred from the documents).
    content_field = _detect_documents_content_field(
        documents=documents,
        requested_content_field=norm_content_field,
    )
    logger.info(
        "vector store autodetect: final_content_field = %s", content_field
    )

    # Step 3: select the codec for the (vectorize, flatness) combination.
    if has_vectorize and flat_layout:
        return _FlatVectorizeVSDocumentCodec(
            ignore_invalid_documents=ignore_invalid_documents,
        )
    if has_vectorize:
        return _DefaultVectorizeVSDocumentCodec(
            ignore_invalid_documents=ignore_invalid_documents,
        )
    if flat_layout:
        return _FlatVSDocumentCodec(
            content_field=content_field,
            ignore_invalid_documents=ignore_invalid_documents,
        )
    return _DefaultVSDocumentCodec(
        content_field=content_field,
        ignore_invalid_documents=ignore_invalid_documents,
    )
Loading

0 comments on commit 2e66453

Please sign in to comment.