Skip to content

Commit

Permalink
checkin
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Oct 18, 2024
1 parent a32505e commit f057907
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 54 deletions.
6 changes: 4 additions & 2 deletions py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ def supported_providers(self) -> list[str]:
class DatabaseProvider(Provider):
def __init__(self, config: DatabaseConfig):
logger.info(f"Initializing DatabaseProvider with config {config}.")

self.handle: Any = None # TODO - Type this properly, we later use it as a PostgresDBHandle

self.handle: Any = (
None # TODO - Type this properly, we later use it as a PostgresHandle
)
super().__init__(config)

@abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,12 @@ async def update_collection(

@telemetry_event("DeleteCollection")
async def delete_collection(self, collection_id: UUID) -> bool:
await self.providers.database.handle.delete_collection_relational(collection_id)
await self.providers.database.handle.delete_collection_vector(collection_id)
await self.providers.database.handle.delete_collection_relational(
collection_id
)
await self.providers.database.handle.delete_collection_vector(
collection_id
)
return True

@telemetry_event("ListCollections")
Expand Down
8 changes: 2 additions & 6 deletions py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def get_current_active_user(

async def register(self, email: str, password: str) -> Dict[str, str]:
# Create new user and give them a default collection
new_user = await self.db_provider.handle.create_user(
email, password
)
new_user = await self.db_provider.handle.create_user(email, password)
default_collection = (
await self.db_provider.handle.create_default_collection(
new_user.id,
Expand Down Expand Up @@ -163,9 +161,7 @@ async def register(self, email: str, password: str) -> Dict[str, str]:
await self.db_provider.handle.store_verification_code(
new_user.id, None, None
)
await self.db_provider.handle.mark_user_as_verified(
new_user.id
)
await self.db_provider.handle.mark_user_as_verified(new_user.id)

return new_user

Expand Down
3 changes: 3 additions & 0 deletions py/core/providers/database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def execute_query(
):
raise NotImplementedError("Subclasses must implement this method")

async def execute_many(self, query, params=None, batch_size=1000):
raise NotImplementedError("Subclasses must implement this method")

def fetch_query(
self,
query: Union[str, TextClause],
Expand Down
6 changes: 3 additions & 3 deletions py/core/providers/database/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logger = logging.getLogger()


class PostgresDBHandle(
class PostgresHandle(
DocumentMixin,
CollectionMixin,
BlacklistedTokensMixin,
Expand Down Expand Up @@ -44,7 +44,7 @@ def _get_table_name(self, base_name: str) -> str:
return f"{self.project_name}.{base_name}"

async def initialize(self, pool: asyncpg.pool.Pool):
logger.info("Initializing `PostgresDBHandle` with connection pool.")
logger.info("Initializing `PostgresHandle` with connection pool.")

self.pool = pool

Expand All @@ -58,7 +58,7 @@ async def initialize(self, pool: asyncpg.pool.Pool):

await self.initialize_vector_db()

logger.info("Successfully initialized `PostgresDBHandle`")
logger.info("Successfully initialized `PostgresHandle`")

async def close(self):
if self.pool:
Expand Down
6 changes: 3 additions & 3 deletions py/core/providers/database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
VectorQuantizationType,
)

from .handle import PostgresDBHandle
from .handle import PostgresHandle

logger = logging.getLogger()

Expand Down Expand Up @@ -142,7 +142,7 @@ def __init__(
config.default_collection_description
)

self.handle: Optional[PostgresDBHandle] = None
self.handle: Optional[PostgresHandle] = None

def _get_table_name(self, base_name: str) -> str:
return f"{self.project_name}.{base_name}"
Expand All @@ -153,7 +153,7 @@ async def initialize(self):
)
await shared_pool.initialize()

handle = PostgresDBHandle(
handle = PostgresHandle(
self.config,
connection_string=self.connection_string,
project_name=self.project_name,
Expand Down
55 changes: 23 additions & 32 deletions py/core/providers/database/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
VectorTableName,
)

from .base import DatabaseMixin, QueryBuilder
from .base import DatabaseMixin
from .vecs.exc import ArgError

logger = logging.getLogger()
Expand All @@ -30,15 +30,18 @@ def index_measure_to_ops(


class VectorDBMixin(DatabaseMixin):
COLUMN_NAME = "vecs"

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.project_name = kwargs.get("project_name")
self.dimension = kwargs.get("dimension")
self.quantization_type = kwargs.get("quantization_type")

async def initialize_vector_db(self):
# Create the vector table if it doesn't exist
query = f"""
CREATE TABLE IF NOT EXISTS {self.project_name}.vectors (
CREATE TABLE IF NOT EXISTS {self.project_name}.{VectorDBMixin.COLUMN_NAME} (
extraction_id TEXT PRIMARY KEY,
document_id TEXT,
user_id TEXT,
Expand All @@ -47,16 +50,16 @@ async def initialize_vector_db(self):
text TEXT,
metadata JSONB
);
CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self.project_name}.vectors (document_id);
CREATE INDEX IF NOT EXISTS idx_vectors_user_id ON {self.project_name}.vectors (user_id);
CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self.project_name}.vectors USING GIN (collection_ids);
CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self.project_name}.vectors USING GIN (to_tsvector('english', text));
CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self.project_name}.{VectorDBMixin.COLUMN_NAME} (document_id);
CREATE INDEX IF NOT EXISTS idx_vectors_user_id ON {self.project_name}.{VectorDBMixin.COLUMN_NAME} (user_id);
CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self.project_name}.{VectorDBMixin.COLUMN_NAME} USING GIN (collection_ids);
CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self.project_name}.{VectorDBMixin.COLUMN_NAME} USING GIN (to_tsvector('english', text));
"""
await self.execute_query(query)

async def upsert(self, entry: VectorEntry) -> None:
query = f"""
INSERT INTO {self.project_name}.vectors
INSERT INTO {self.project_name}.{VectorDBMixin.COLUMN_NAME}
(extraction_id, document_id, user_id, collection_ids, vector, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (extraction_id) DO UPDATE
Expand All @@ -77,7 +80,7 @@ async def upsert(self, entry: VectorEntry) -> None:

async def upsert_entries(self, entries: list[VectorEntry]) -> None:
query = f"""
INSERT INTO {self.project_name}.vectors
INSERT INTO {self.project_name}.{VectorDBMixin.COLUMN_NAME}
(extraction_id, document_id, user_id, collection_ids, vector, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (extraction_id) DO UPDATE
Expand All @@ -103,7 +106,7 @@ async def semantic_search(
query = f"""
SELECT extraction_id, document_id, user_id, collection_ids, text,
1 - (vector <=> $1::vector) as similarity, metadata
FROM {self.project_name}.vectors
FROM {self.project_name}.{VectorDBMixin.COLUMN_NAME}
WHERE collection_ids && $2
ORDER BY similarity DESC
LIMIT $3 OFFSET $4;
Expand Down Expand Up @@ -138,7 +141,7 @@ async def full_text_search(
SELECT extraction_id, document_id, user_id, collection_ids, text,
ts_rank_cd(to_tsvector('english', text), plainto_tsquery('english', $1)) as rank,
metadata
FROM {self.project_name}.vectors
FROM {self.project_name}.{VectorDBMixin.COLUMN_NAME}
WHERE collection_ids && $2 AND to_tsvector('english', text) @@ plainto_tsquery('english', $1)
ORDER BY rank DESC
LIMIT $3 OFFSET $4;
Expand Down Expand Up @@ -299,7 +302,7 @@ async def delete(

where_clause = " AND ".join(conditions)
query = f"""
DELETE FROM {self.project_name}.vectors
DELETE FROM {self.project_name}.{VectorDBMixin.COLUMN_NAME}
WHERE {where_clause}
RETURNING extraction_id;
"""
Expand All @@ -313,7 +316,7 @@ async def assign_document_to_collection_vector(
self, document_id: str, collection_id: str
) -> None:
query = f"""
UPDATE {self.project_name}.vectors
UPDATE {self.project_name}.{VectorDBMixin.COLUMN_NAME}
SET collection_ids = array_append(collection_ids, $1)
WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
"""
Expand All @@ -323,22 +326,22 @@ async def remove_document_from_collection_vector(
self, document_id: str, collection_id: str
) -> None:
query = f"""
UPDATE {self.project_name}.vectors
UPDATE {self.project_name}.{VectorDBMixin.COLUMN_NAME}
SET collection_ids = array_remove(collection_ids, $1)
WHERE document_id = $2;
"""
await self.execute_query(query, (collection_id, document_id))

async def delete_user_vector(self, user_id: str) -> None:
query = f"""
DELETE FROM {self.project_name}.vectors
DELETE FROM {self.project_name}.{VectorDBMixin.COLUMN_NAME}
WHERE user_id = $1;
"""
await self.execute_query(query, (user_id,))

async def delete_collection_vector(self, collection_id: str) -> None:
query = f"""
DELETE FROM {self.project_name}.vectors
DELETE FROM {self.project_name}.{VectorDBMixin.COLUMN_NAME}
WHERE $1 = ANY(collection_ids);
"""
await self.execute_query(query, (collection_id,))
Expand All @@ -356,7 +359,7 @@ async def get_document_chunks(
query = f"""
SELECT extraction_id, document_id, user_id, collection_ids, text, metadata
{vector_select}
FROM {self.project_name}.vectors
FROM {self.project_name}.{VectorDBMixin.COLUMN_NAME}
WHERE document_id = $1
OFFSET $2
{limit_clause};
Expand Down Expand Up @@ -428,17 +431,13 @@ async def create_index(
"""

if table_name == VectorTableName.CHUNKS:
table_name = f"{self.client.project_name}.{self.table.name}"
table_name = f"{self.project_name}.{self.table.name}"
col_name = "vec"
elif table_name == VectorTableName.ENTITIES:
table_name = (
f"{self.client.project_name}.{VectorTableName.ENTITIES}"
)
table_name = f"{self.project_name}.{VectorTableName.ENTITIES}"
col_name = "description_embedding"
elif table_name == VectorTableName.COMMUNITIES:
table_name = (
f"{self.client.project_name}.{VectorTableName.COMMUNITIES}"
)
table_name = f"{self.project_name}.{VectorTableName.COMMUNITIES}"
col_name = "embedding"
else:
raise ArgError("invalid table name")
Expand Down Expand Up @@ -471,15 +470,7 @@ async def create_index(
)

if method == IndexMethod.auto:
if self.client._supports_hnsw():
method = IndexMethod.hnsw
else:
method = IndexMethod.ivfflat

if method == IndexMethod.hnsw and not self.client._supports_hnsw():
raise ArgError(
"HNSW Unavailable. Upgrade your pgvector installation to > 0.5.0 to enable HNSW support"
)
method = IndexMethod.hnsw

ops = index_measure_to_ops(
measure, quantization_type=self.quantization_type
Expand Down
6 changes: 0 additions & 6 deletions py/core/providers/kg/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,6 @@ async def add_entities(
Returns:
result: asyncpg.Record: result of the upsert operation
"""
for entity in entities:
if entity.description_embedding is not None:
entity.description_embedding = str(
entity.description_embedding
)

return await self._add_objects(entities, table_name)

async def add_triples(
Expand Down

0 comments on commit f057907

Please sign in to comment.