Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
emrgnt-cmplxty committed Oct 18, 2024
1 parent f057907 commit 681d276
Show file tree
Hide file tree
Showing 13 changed files with 396 additions and 191 deletions.
1 change: 1 addition & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"HybridSearchSettings",
# User abstractions
"Token",
"TokenData",
Expand Down
1 change: 1 addition & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
"HybridSearchSettings",
# KG abstractions
"KGCreationSettings",
"KGEnrichmentSettings",
Expand Down
43 changes: 42 additions & 1 deletion py/core/providers/database/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,48 @@
from typing import Any, Optional, Sequence, Union
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union

import asyncpg
from sqlalchemy import TextClause, text

if TYPE_CHECKING:
from core.providers.database.handle import PostgresHandle


logger = logging.getLogger()


class SemaphoreConnectionPool:
    """Semaphore-throttled wrapper around an ``asyncpg`` connection pool.

    The wrapper owns the underlying pool (stored on ``self.pool``) and a
    semaphore that caps the number of concurrently checked-out connections
    at roughly 90% of ``max_connections``.

    Args:
        connection_string: DSN passed to ``asyncpg.create_pool``.
        postgres_configuration_settings: Settings object exposing a numeric
            ``max_connections`` attribute.
    """

    def __init__(self, connection_string, postgres_configuration_settings):
        self.connection_string = connection_string
        self.postgres_configuration_settings = postgres_configuration_settings
        # Populated by ``initialize``. Defined up front so that premature use
        # can be detected explicitly instead of failing with AttributeError.
        self.semaphore = None
        self.pool = None

    def _semaphore_limit(self) -> int:
        """Return the number of connections that may be in use at once.

        This is 90% of ``max_connections`` rounded down, but never less than
        one: a semaphore created with a value of 0 would make every
        ``get_connection`` call wait forever.
        """
        max_connections = self.postgres_configuration_settings.max_connections
        return max(1, int(max_connections * 0.9))

    async def initialize(self):
        """Create the semaphore and the underlying ``asyncpg`` pool.

        Raises:
            ValueError: If anything goes wrong while creating the pool; the
                original exception is chained as the cause.
        """
        try:
            self.semaphore = asyncio.Semaphore(self._semaphore_limit())

            self.pool = await asyncpg.create_pool(
                self.connection_string,
                max_size=self.postgres_configuration_settings.max_connections,
            )

            logger.info(
                "Successfully connected to Postgres database and created connection pool."
            )
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to relational database."
            ) from e

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled connection, holding a semaphore slot meanwhile.

        Raises:
            RuntimeError: If called before ``initialize`` has completed.
        """
        if self.semaphore is None or self.pool is None:
            raise RuntimeError(
                "SemaphoreConnectionPool.initialize() must be awaited before "
                "requesting a connection."
            )
        async with self.semaphore:
            async with self.pool.acquire() as conn:
                yield conn

    async def close(self):
        """Close the underlying pool, if one was created.

        Safe to call more than once and before ``initialize``; in both cases
        it is a no-op.
        """
        if self.pool is not None:
            await self.pool.close()
            self.pool = None


class QueryBuilder:
def __init__(self, table_name: str):
Expand Down
4 changes: 2 additions & 2 deletions py/core/providers/database/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ async def create_collection(
]

try:
async with self.pool.acquire() as conn: # type: ignore
async with self.pool.get_connection() as conn: # type: ignore
row = await conn.fetchrow(query, *params)

if not row:
Expand Down Expand Up @@ -176,7 +176,7 @@ async def update_collection(
)

async def delete_collection_relational(self, collection_id: UUID) -> None:
async with self.pool.acquire() as conn: # type: ignore
async with self.pool.get_connection() as conn: # type: ignore
async with conn.transaction():
try:
# Remove collection_id from users
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def upsert_documents_overview(
retries = 0
while retries < max_retries:
try:
async with self.pool.acquire() as conn: # type: ignore
async with self.pool.get_connection() as conn: # type: ignore
async with conn.transaction():
# Lock the row for update
check_query = f"""
Expand Down
48 changes: 32 additions & 16 deletions py/core/providers/database/handle.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import logging
from typing import Optional

import asyncpg

from core.base import CryptoProvider, DatabaseConfig
from core.providers.database.base import DatabaseMixin
from core.providers.database.base import (
DatabaseMixin,
SemaphoreConnectionPool,
logger,
)
from core.providers.database.collection import CollectionMixin
from core.providers.database.document import DocumentMixin
from core.providers.database.tokens import BlacklistedTokensMixin
from core.providers.database.user import UserMixin
from core.providers.database.vector import VectorDBMixin
from shared.abstractions.vector import VectorQuantizationType

logger = logging.getLogger()


class PostgresHandle(
DocumentMixin,
Expand All @@ -30,42 +29,52 @@ def __init__(
project_name: str,
dimension: int,
quantization_type: Optional[VectorQuantizationType] = None,
pool_size: int = 10,
max_retries: int = 3,
retry_delay: int = 1,
):
super().__init__(config)
self.config = config
self.connection_string = connection_string
self.crypto_provider = crypto_provider
self.project_name = project_name
self.dimension = dimension
self.quantization_type = quantization_type
self.pool = None
self.pool_size = pool_size
self.max_retries = max_retries
self.retry_delay = retry_delay
self.pool: Optional[SemaphoreConnectionPool] = None

def _get_table_name(self, base_name: str) -> str:
return f"{self.project_name}.{base_name}"

async def initialize(self, pool: asyncpg.pool.Pool):
logger.info("Initializing `PostgresHandle` with connection pool.")

async def initialize(self, pool: SemaphoreConnectionPool):
logger.info("Initializing `PostgresDBHandle`.")
self.pool = pool

async with self.pool.get_connection() as conn:
await conn.execute(f'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")

# Create schema if it doesn't exist
await conn.execute(
f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
)

# Call create_table for each mixin
for base_class in self.__class__.__bases__:
if issubclass(base_class, DatabaseMixin):
await base_class.create_table(self)

await self.initialize_vector_db()

logger.info("Successfully initialized `PostgresHandle`")
logger.info("Successfully initialized `PostgresDBHandle`")

async def close(self):
if self.pool:
await self.pool.close()

async def execute_query(self, query, params=None, isolation_level=None):
async with self.pool.acquire() as conn:
async with self.pool.get_connection() as conn:
if isolation_level:
async with conn.transaction(isolation=isolation_level):
if params:
Expand Down Expand Up @@ -104,3 +113,10 @@ async def fetchrow_query(self, query, params=None):
return await conn.fetchrow(query, *params)
else:
return await conn.fetchrow(query)

async def __aenter__(self):
await self.initialize()
return self

async def __aexit__(self, exc_type, exc, tb):
await self.close()
45 changes: 6 additions & 39 deletions py/core/providers/database/postgres.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
# TODO: Clean this up and make it more congruent across the vector database and the relational database.
import asyncio
import logging
import os
import warnings
from contextlib import asynccontextmanager
from typing import Any, Optional

import asyncpg

from core.base import (
CryptoProvider,
DatabaseConfig,
Expand All @@ -16,6 +12,7 @@
VectorQuantizationType,
)

from .base import SemaphoreConnectionPool
from .handle import PostgresHandle

logger = logging.getLogger()
Expand All @@ -30,37 +27,6 @@ def get_env_var(new_var, old_var, config_value):
return value


# Pool wrapper that gates connection checkout behind an asyncio semaphore.
# NOTE(review): this class subclasses asyncpg.Pool, but __init__ below does not
# call super().__init__(); the pool actually used by get_connection() is the
# separate object that initialize() stores on self.pool.
class SemaphoreConnectionPool(asyncpg.Pool):
def __init__(self, connection_string, postgres_configuration_settings):
self.connection_string = connection_string
self.postgres_configuration_settings = postgres_configuration_settings

# Creates the semaphore and the asyncpg pool; any failure is re-raised as a
# ValueError chained to the original exception.
async def initialize(self):
try:
# Semaphore size is 90% of max_connections, truncated to an int.
self.semaphore = asyncio.Semaphore(
int(self.postgres_configuration_settings.max_connections * 0.9)
)

self.pool = await asyncpg.create_pool(
self.connection_string,
max_size=self.postgres_configuration_settings.max_connections,
)

logger.info(
"Successfully connected to Postgres database and created connection pool."
)
except Exception as e:
raise ValueError(
f"Error {e} occurred while attempting to connect to relational database."
) from e

# Async context manager: takes a semaphore slot, then acquires a connection
# from self.pool and yields it; both are released when the context exits.
@asynccontextmanager
async def get_connection(self):
async with self.semaphore:
async with self.pool.acquire() as conn:
yield conn


class PostgresDBProvider(DatabaseProvider):
user: str
password: str
Expand Down Expand Up @@ -148,20 +114,21 @@ def _get_table_name(self, base_name: str) -> str:
return f"{self.project_name}.{base_name}"

async def initialize(self):
shared_pool = SemaphoreConnectionPool(
pool = SemaphoreConnectionPool(
self.connection_string, self.postgres_configuration_settings
)
await shared_pool.initialize()
await pool.initialize()

handle = PostgresHandle(
self.config,
connection_string=self.connection_string,
crypto_provider=self.crypto_provider,
project_name=self.project_name,
dimension=self.vector_db_dimension,
quantization_type=self.vector_db_quantization_type,
)
await handle.initialize(shared_pool)

await handle.initialize(pool)
self.pool = pool
self.handle = handle

def _get_postgres_configuration_settings(
Expand Down
13 changes: 0 additions & 13 deletions py/core/providers/database/vecs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
from . import exc
from .client import Client
from .vector_collection import VectorCollection

__project__ = "vecs"
__version__ = "0.4.2"


__all__ = [
"VectorCollection",
"Client",
"exc",
]


def create_client(connection_string: str, *args, **kwargs) -> Client:
    """Build a ``Client`` for the given Postgres connection string.

    Any extra positional or keyword arguments are forwarded to ``Client``
    unchanged.
    """
    client = Client(connection_string, *args, **kwargs)
    return client
Loading

0 comments on commit 681d276

Please sign in to comment.