Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async #1

Merged
merged 23 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
bff2aaa
➕ UPDATE dependencies
agn-7 Nov 28, 2023
f046c30
⚡️ 🗃️ feat(database): switch to asyncpg as PostgreSQL interface
agn-7 Nov 29, 2023
5c17846
🗃️ feat(database): migrate to async SQLAlchemy engine
agn-7 Nov 29, 2023
23190c9
⚡️ feat(database): transition to asynchronous SQLAlchemy operations
agn-7 Nov 29, 2023
13c0e73
⚡️ optimized the query using SQL Join
agn-7 Nov 29, 2023
3ac0a2a
🗃️ feat(models): optimize interaction-message relationship loading
agn-7 Nov 29, 2023
8afc586
⚡️ feat(crud): convert create_interaction to async and add interactio…
agn-7 Nov 29, 2023
0e095f5
⚡️ changed the query from selectin to joinedload
agn-7 Nov 29, 2023
9a9658e
🩹 simplified the query using selectin within the relationship itself …
agn-7 Nov 29, 2023
ae493c2
🦺 update settings schema to support all models
agn-7 Nov 29, 2023
e935318
🐛 fixed model typo name
agn-7 Nov 29, 2023
2089145
⚡️ made async
agn-7 Nov 29, 2023
ce3d184
⚡️ Switch to async database operations
agn-7 Nov 29, 2023
ca65ce5
⚡️ update methods to use SQLAlchemy Core and async
agn-7 Nov 30, 2023
d17e405
🏷️ add hinting type
agn-7 Dec 1, 2023
119e591
➕ UPDATE dependencies
agn-7 Dec 1, 2023
2bd4e2b
🔨 update tests to async
agn-7 Dec 1, 2023
8ece2d9
🔨 update to async version
agn-7 Dec 1, 2023
0cd62ab
🐛 fixed postgresql async url
agn-7 Dec 1, 2023
c5490f6
👷 add codecov policy
agn-7 Dec 1, 2023
9c52bfe
📝 update SQLALCHEMY_DATABASE_URI to use asyncpg
agn-7 Dec 1, 2023
5424607
♻️ refactor pytest fixture
agn-7 Dec 1, 2023
3db4423
♻️ refactor to use pytest_asyncio for the fixture
agn-7 Dec 1, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
coverage:
status:
project:
default:
# basic
target: auto
threshold: null
4 changes: 2 additions & 2 deletions .github/workflows/github-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ jobs:
run: |
poetry run pytest -s
env:
SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db
SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db
- name: Code Coverage
run: |
poetry run pytest --cov=./ifsguid --cov-report=xml --cov-report=term-missing
env:
SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db
SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
env:
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,5 @@ To launch an API instance, you should:
You can also run the project via `docker-compose` (i.e. `docker compose up -d`) on port `80` in which you would need the [.docker.env](/.docker.env) containing the following variable to create the database:

```
SQLALCHEMY_DATABASE_URI=postgresql://<username>:<password>@ifsguid_db/<db-name>
SQLALCHEMY_DATABASE_URI=postgresql+asyncpg://<username>:<password>@ifsguid_db/<db-name>
```


4 changes: 3 additions & 1 deletion alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def run_migrations_online() -> None:

"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = f"{settings.SQLALCHEMY_DATABASE_URI}"
configuration[
"sqlalchemy.url"
] = settings.SQLALCHEMY_DATABASE_URI.unicode_string().replace("+asyncpg", "")
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
Expand Down
2 changes: 1 addition & 1 deletion ifsguid/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def assemble_db_connection(cls, v: Optional[str], values: ValidationInfo) -> Any
return v
print("Creating SQLALCHEMY_DATABASE_URI from .env file ...")
return PostgresDsn.build(
scheme="postgresql",
scheme="postgresql+asyncpg",
username=values.data.get("POSTGRES_USER"),
password=values.data.get("POSTGRES_PASSWORD"),
host=values.data.get("POSTGRES_SERVER"),
Expand Down
81 changes: 46 additions & 35 deletions ifsguid/crud.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,90 @@
from typing import List
from uuid import UUID

from sqlalchemy.orm import Session
from sqlalchemy import delete, update
from sqlalchemy.orm import joinedload
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession

from . import models, schemas, utils
from . import models, schemas


def get_interactions(
db: Session, page: int = None, per_page: int = 10
async def get_interactions(
db: AsyncSession, page: int = None, per_page: int = 10
) -> List[models.Interaction]:
query = db.query(models.Interaction)
stmt = select(models.Interaction).options(joinedload(models.Interaction.messages))

if page is not None:
query = query.offset((page - 1) * per_page).limit(per_page)
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

return query.all()
result = await db.execute(stmt)
return result.scalars().unique().all()


def get_interaction(db: Session, id: UUID) -> models.Interaction:
return db.query(models.Interaction).filter(models.Interaction.id == id).first()
async def get_interaction(db: AsyncSession, id: UUID) -> models.Interaction:
stmt = select(models.Interaction).where(models.Interaction.id == id)
result = await db.execute(stmt)
return result.scalar()


def create_interaction(db: Session, settings: schemas.Settings) -> models.Interaction:
async def create_interaction(
db: AsyncSession, settings: schemas.Settings
) -> models.Interaction:
interaction = models.Interaction(
settings=settings.model_dump(),
)
db.add(interaction)
db.commit()
await db.commit()
await db.refresh(interaction)

return interaction


def delete_interaction(db: Session, id: UUID) -> None:
interaction = (
db.query(models.Interaction).filter(models.Interaction.id == id).first()
)
async def delete_interaction(db: AsyncSession, id: UUID) -> None:
stmt = delete(models.Interaction).where(models.Interaction.id == id)
result = await db.execute(stmt)

if interaction is not None:
db.delete(interaction)
db.commit()
if result.rowcount:
await db.commit()
return True

return False


def update_interaction(
db: Session, id: UUID, settings: schemas.Settings
async def update_interaction(
db: AsyncSession, id: UUID, settings: schemas.Settings
) -> models.Interaction:
interaction: models.Interaction = (
db.query(models.Interaction).filter(models.Interaction.id == id).first()
stmt = (
update(models.Interaction)
.where(models.Interaction.id == id)
.values(settings=settings)
)
result = await db.execute(stmt)

if interaction is not None:
interaction.settings = settings
db.commit()
return interaction
if result.rowcount:
await db.commit()
return True

return None


def get_messages(
db: Session, interaction_id: UUID = None, page: int = None, per_page: int = 10
async def get_messages(
db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10
) -> List[models.Message]:
query = db.query(models.Message)
stmt = select(models.Message)

if interaction_id is not None:
query = query.filter(models.Message.interaction_id == interaction_id)
stmt = stmt.where(models.Message.interaction_id == interaction_id)

if page is not None:
query = query.offset((page - 1) * per_page).limit(per_page)
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

return query.all()
result = await db.execute(stmt)
return result.scalars().all()


def create_message(
db: Session, messages: List[schemas.MessageCreate], interaction_id: UUID
async def create_message(
db: AsyncSession, messages: List[schemas.MessageCreate], interaction_id: UUID
) -> List[models.Message]:
messages_db = []
for msg in messages:
Expand All @@ -84,5 +95,5 @@ def create_message(
db.add(message)
messages_db.append(message)

db.commit()
await db.commit()
return messages_db
15 changes: 11 additions & 4 deletions ifsguid/database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.orm import sessionmaker

from .config import settings

engine = create_engine(
settings.SQLALCHEMY_DATABASE_URI.unicode_string(), pool_pre_ping=True
engine: AsyncEngine = create_async_engine(
settings.SQLALCHEMY_DATABASE_URI.unicode_string()
)
async_session = sessionmaker(
autocommit=False,
autoflush=False,
future=True,
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True, bind=engine)
75 changes: 39 additions & 36 deletions ifsguid/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,42 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from g4f import ModelUtils

from . import crud, schemas, modules
from .database import SessionLocal
from .database import async_session, AsyncSession

router = APIRouter()


def get_db() -> SessionLocal:
try:
db = SessionLocal()
yield db
finally:
db.close()
async def get_db() -> AsyncSession:
async with async_session() as session:
yield session


@router.get("/", response_model=str)
async def get_root(db: Session = Depends(get_db)) -> str:
async def get_root(db: AsyncSession = Depends(get_db)) -> str:
return "Hello from IFSGuid!"


@router.get("/interactions", response_model=List[schemas.Interaction])
async def get_all_interactions(
page: Optional[int] = None,
per_page: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> List[schemas.Interaction]:
interactions = await crud.get_interactions(db=db, page=page, per_page=per_page)

return [
schemas.Interaction.model_validate(interaction)
for interaction in crud.get_interactions(db=db, page=page, per_page=per_page)
schemas.Interaction.model_validate(interaction) for interaction in interactions
]


@router.get(
"/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False
)
async def get_interactions(
id: UUID, db: Session = Depends(get_db)
id: UUID, db: AsyncSession = Depends(get_db)
) -> schemas.Interaction:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
Expand All @@ -48,15 +46,20 @@ async def get_interactions(

@router.post("/interactions", response_model=schemas.Interaction)
async def create_interactions(
settings: schemas.Settings, db: Session = Depends(get_db)
prompt: schemas.Prompt,
chat_model: schemas.ChatModel = Depends(),
db: AsyncSession = Depends(get_db),
) -> schemas.Interaction:
return schemas.Interaction.model_validate(
crud.create_interaction(db=db, settings=settings)
settings = schemas.Settings(
model=chat_model.model, prompt=prompt.prompt, role=prompt.role
)
interaction = await crud.create_interaction(db=db, settings=settings)

return schemas.Interaction.model_validate(interaction)


@router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False)
async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None:
async def delete_interaction(id: UUID, db: AsyncSession = Depends(get_db)) -> None:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
)
Expand All @@ -66,7 +69,7 @@ async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None:
"/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False
)
async def update_interaction(
id: UUID, settings: schemas.Settings, db: Session = Depends(get_db)
id: UUID, settings: schemas.Settings, db: AsyncSession = Depends(get_db)
) -> schemas.Interaction:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
Expand All @@ -80,30 +83,31 @@ async def get_all_message_in_interaction(
interaction_id: UUID,
page: Optional[int] = None,
per_page: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> List[schemas.Message]:
interaction = crud.get_interaction(db=db, id=str(interaction_id))
interaction = await crud.get_interaction(db=db, id=str(interaction_id))

if not interaction:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Interaction not found"
)

return [
schemas.Message.model_validate(message)
for message in crud.get_messages(
db=db, interaction_id=str(interaction_id), page=page, per_page=per_page
)
]
messages = await crud.get_messages(
db=db, interaction_id=str(interaction_id), page=page, per_page=per_page
)

return [schemas.Message.model_validate(message) for message in messages]


@router.post(
"/interactions/{interactions_id}/messages", response_model=List[schemas.Message]
)
async def create_message(
interaction_id: UUID, message: schemas.MessageCreate, db: Session = Depends(get_db)
interaction_id: UUID,
message: schemas.MessageCreate,
db: AsyncSession = Depends(get_db),
) -> schemas.Message:
interaction = crud.get_interaction(db=db, id=str(interaction_id))
interaction = await crud.get_interaction(db=db, id=str(interaction_id))

if not interaction:
raise HTTPException(
Expand All @@ -114,17 +118,16 @@ async def create_message(

messages = []
if message.role == "human":
ai_content = modules.generate_ai_response(
content=message.content, model=interaction.settings.model_name
ai_content = await modules.generate_ai_response(
content=message.content,
model=ModelUtils.convert[interaction.settings.model],
)
ai_message = schemas.MessageCreate(role="ai", content=ai_content)

messages.append(message)
messages.append(ai_message)

return [
schemas.Message.model_validate(message)
for message in crud.create_message(
db=db, messages=messages, interaction_id=str(interaction_id)
)
]
messages = await crud.create_message(
db=db, messages=messages, interaction_id=str(interaction_id)
)
return [schemas.Message.model_validate(message) for message in messages]
2 changes: 1 addition & 1 deletion ifsguid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Interaction(Base):
)
settings = Column(JSON)

messages = relationship("Message", back_populates="interaction")
messages = relationship("Message", back_populates="interaction", lazy="selectin")


class Message(Base):
Expand Down
20 changes: 13 additions & 7 deletions ifsguid/modules.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import g4f
import traceback

g4f.debug.logging = True
g4f.check_version = False


def generate_ai_response(
content: str, model: str = "GPT4"
) -> str: # TODO: use async version for future
return g4f.ChatCompletion.create(
model=g4f.models.gpt_4 if model == "GPT4" else g4f.models.gpt_35_turbo,
messages=[{"role": "human", "content": content}],
)
async def generate_ai_response(
content: str, model: g4f.Model = g4f.models.default
) -> str:
try:
response = await g4f.ChatCompletion.create_async(
model=model,
messages=[{"role": "human", "content": content}],
)
return response
except Exception:
traceback.print_exc()
return "error!"
Loading