Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jordanrfrazier committed Apr 22, 2024
1 parent a9023dc commit 40c2c96
Showing 1 changed file with 42 additions and 74 deletions.
116 changes: 42 additions & 74 deletions libs/astradb/langchain_astradb/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def __init__(
collection_vector_service_options: specifies the use of server-side
embeddings within Astra DB. Only one of `embedding` or
`collection_vector_service_options` can be provided.
NOTE: this feature is currently in beta.
NOTE: This feature is under current development.
Note:
For concurrency in synchronous :meth:`~add_texts`:, as a rule of thumb, on a
Expand Down Expand Up @@ -326,16 +327,17 @@ async def _aget_embedding_dimension(self) -> int:
return self.embedding_dimension

@property
def embeddings(self) -> Embeddings:
if self.collection_vector_service_options is not None:
raise ValueError(
"Server-side embeddings are in use, no client-side embeddings\
available."
)

assert self.embedding is not None
def embeddings(self) -> Optional[Embeddings]:
    """Return the client-side ``Embeddings`` object, if one was supplied.

    When server-side (vectorize) embeddings are in use, no client-side
    embedding object exists and this returns ``None``.
    """
    return self.embedding

def _using_vectorize(self) -> bool:
"""Indicates whether server-side embeddings are being used."""
return self._using_vectorize()

def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The underlying API calls already returns a "score proper",
Expand Down Expand Up @@ -624,18 +626,16 @@ def add_texts(
)
self.astra_env.ensure_db_setup()

if self.collection_vector_service_options is not None:
# using server-side embeddings
if self._using_vectorize():
documents_to_insert = self._get_vectorize_documents_to_insert(
texts, metadatas, ids
)
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vectors = self.embedding.embed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
texts, embedding_vectors, metadatas, ids
)
else:
raise ValueError("No embeddings or vectorization service available.")

def _handle_batch(document_batch: List[DocDict]) -> List[str]:
# self.collection is not None (by _ensure_astra_db_client)
Expand Down Expand Up @@ -728,18 +728,17 @@ async def aadd_texts(
)
await self.astra_env.aensure_db_setup()

if self.collection_vector_service_options is not None:
if self._using_vectorize():
# using server-side embeddings
documents_to_insert = self._get_vectorize_documents_to_insert(
texts, metadatas, ids
)
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vectors = await self.embedding.aembed_documents(list(texts))
documents_to_insert = self._get_documents_to_insert(
texts, embedding_vectors, metadatas, ids
)
else:
raise ValueError("No embeddings or vectorization service available.")

async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
# self.async_collection is not None here for sure
Expand Down Expand Up @@ -959,24 +958,21 @@ def similarity_search_with_score_id(
Returns:
The list of (Document, score, id), the most similar to the query.
"""
if self.collection_vector_service_options is not None:

if self._using_vectorize():
return self._similarity_search_with_score_id_with_vectorize(
query=query,
k=k,
filter=filter,
)
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or\
`collection_vector_service_options`"
)

async def asimilarity_search_with_score_id(
self,
Expand All @@ -994,24 +990,20 @@ async def asimilarity_search_with_score_id(
Returns:
The list of (Document, score, id), the most similar to the query.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
return await self._asimilarity_search_with_score_id_with_vectorize(
query=query,
k=k,
filter=filter,
)
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or\
`collection_vector_service_options`"
)

def similarity_search_with_score_by_vector(
self,
Expand Down Expand Up @@ -1084,7 +1076,7 @@ def similarity_search(
Returns:
The list of Documents most similar to the query.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
return [
doc
for (doc, _, _) in self._similarity_search_with_score_id_with_vectorize(
Expand All @@ -1093,18 +1085,14 @@ def similarity_search(
filter=filter,
)
]
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or \
`collection_vector_service_options`"
)

async def asimilarity_search(
self,
Expand All @@ -1123,7 +1111,7 @@ async def asimilarity_search(
Returns:
The list of Documents most similar to the query.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
return [
doc
for (
Expand All @@ -1136,18 +1124,14 @@ async def asimilarity_search(
filter=filter,
)
]
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_by_vector(
embedding_vector,
k,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or \
`collection_vector_service_options`"
)

def similarity_search_by_vector(
self,
Expand Down Expand Up @@ -1217,7 +1201,7 @@ def similarity_search_with_score(
Returns:
The list of (Document, score), the most similar to the query vector.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
return [
(doc, score)
for (
Expand All @@ -1230,18 +1214,14 @@ def similarity_search_with_score(
filter=filter,
)
]
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
embedding_vector,
k,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or \
`collection_vector_service_options`"
)

async def asimilarity_search_with_score(
self,
Expand All @@ -1259,7 +1239,7 @@ async def asimilarity_search_with_score(
Returns:
The list of (Document, score), the most similar to the query vector.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
return [
(doc, score)
for (
Expand All @@ -1272,18 +1252,14 @@ async def asimilarity_search_with_score(
filter=filter,
)
]
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = await self.embedding.aembed_query(query)
return await self.asimilarity_search_with_score_by_vector(
embedding_vector,
k,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or \
`collection_vector_service_options`"
)

@staticmethod
def _get_mmr_hits(
Expand Down Expand Up @@ -1426,9 +1402,10 @@ def max_marginal_relevance_search(
Returns:
The list of Documents selected by maximal marginal relevance.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
raise ValueError("MMR search is unsupported for server-side embeddings.")
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = self.embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding_vector,
Expand All @@ -1437,11 +1414,6 @@ def max_marginal_relevance_search(
lambda_mult=lambda_mult,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or \
`collection_vector_service_options`"
)

async def amax_marginal_relevance_search(
self,
Expand Down Expand Up @@ -1469,9 +1441,10 @@ async def amax_marginal_relevance_search(
Returns:
The list of Documents selected by maximal marginal relevance.
"""
if self.collection_vector_service_options is not None:
if self._using_vectorize():
raise ValueError("MMR search is unsupported for server-side embeddings.")
elif self.embedding is not None:
else:
assert self.embedding is not None
embedding_vector = await self.embedding.aembed_query(query)
return await self.amax_marginal_relevance_search_by_vector(
embedding_vector,
Expand All @@ -1480,11 +1453,6 @@ async def amax_marginal_relevance_search(
lambda_mult=lambda_mult,
filter=filter,
)
else:
raise ValueError(
"expected one of `embedding` or \
`collection_vector_service_options`"
)

@classmethod
def _from_kwargs(
Expand Down

0 comments on commit 40c2c96

Please sign in to comment.