Skip to content

Commit

Permalink
Merge pull request #1 from agn-7/async
Browse files Browse the repository at this point in the history
  • Loading branch information
agn-7 authored Dec 1, 2023
2 parents ece7dca + 3db4423 commit 48bf621
Show file tree
Hide file tree
Showing 17 changed files with 283 additions and 138 deletions.
7 changes: 7 additions & 0 deletions .codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
coverage:
status:
project:
default:
# basic
target: auto
threshold: null
4 changes: 2 additions & 2 deletions .github/workflows/github-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ jobs:
run: |
poetry run pytest -s
env:
SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db
SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db
- name: Code Coverage
run: |
poetry run pytest --cov=./ifsguid --cov-report=xml --cov-report=term-missing
env:
SQLALCHEMY_DATABASE_URI: postgresql://ifsguid_usr:root@localhost/ifsguid_db
SQLALCHEMY_DATABASE_URI: postgresql+asyncpg://ifsguid_usr:root@localhost/ifsguid_db
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
env:
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,5 @@ To launch an API instance, you should:
You can also run the project via `docker-compose` (i.e. `docker compose up -d`) on port `80` in which you would need the [.docker.env](/.docker.env) containing the following variable to create the database:

```
SQLALCHEMY_DATABASE_URI=postgresql://<username>:<password>@ifsguid_db/<db-name>
SQLALCHEMY_DATABASE_URI=postgresql+asyncpg://<username>:<password>@ifsguid_db/<db-name>
```


4 changes: 3 additions & 1 deletion alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def run_migrations_online() -> None:
"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = f"{settings.SQLALCHEMY_DATABASE_URI}"
configuration[
"sqlalchemy.url"
] = settings.SQLALCHEMY_DATABASE_URI.unicode_string().replace("+asyncpg", "")
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
Expand Down
2 changes: 1 addition & 1 deletion ifsguid/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def assemble_db_connection(cls, v: Optional[str], values: ValidationInfo) -> Any
return v
print("Creating SQLALCHEMY_DATABASE_URI from .env file ...")
return PostgresDsn.build(
scheme="postgresql",
scheme="postgresql+asyncpg",
username=values.data.get("POSTGRES_USER"),
password=values.data.get("POSTGRES_PASSWORD"),
host=values.data.get("POSTGRES_SERVER"),
Expand Down
81 changes: 46 additions & 35 deletions ifsguid/crud.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,90 @@
from typing import List
from uuid import UUID

from sqlalchemy.orm import Session
from sqlalchemy import delete, update
from sqlalchemy.orm import joinedload
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession

from . import models, schemas, utils
from . import models, schemas


def get_interactions(
db: Session, page: int = None, per_page: int = 10
async def get_interactions(
db: AsyncSession, page: int = None, per_page: int = 10
) -> List[models.Interaction]:
query = db.query(models.Interaction)
stmt = select(models.Interaction).options(joinedload(models.Interaction.messages))

if page is not None:
query = query.offset((page - 1) * per_page).limit(per_page)
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

return query.all()
result = await db.execute(stmt)
return result.scalars().unique().all()


def get_interaction(db: Session, id: UUID) -> models.Interaction:
return db.query(models.Interaction).filter(models.Interaction.id == id).first()
async def get_interaction(db: AsyncSession, id: UUID) -> models.Interaction:
stmt = select(models.Interaction).where(models.Interaction.id == id)
result = await db.execute(stmt)
return result.scalar()


def create_interaction(db: Session, settings: schemas.Settings) -> models.Interaction:
async def create_interaction(
db: AsyncSession, settings: schemas.Settings
) -> models.Interaction:
interaction = models.Interaction(
settings=settings.model_dump(),
)
db.add(interaction)
db.commit()
await db.commit()
await db.refresh(interaction)

return interaction


def delete_interaction(db: Session, id: UUID) -> None:
interaction = (
db.query(models.Interaction).filter(models.Interaction.id == id).first()
)
async def delete_interaction(db: AsyncSession, id: UUID) -> None:
stmt = delete(models.Interaction).where(models.Interaction.id == id)
result = await db.execute(stmt)

if interaction is not None:
db.delete(interaction)
db.commit()
if result.rowcount:
await db.commit()
return True

return False


def update_interaction(
db: Session, id: UUID, settings: schemas.Settings
async def update_interaction(
db: AsyncSession, id: UUID, settings: schemas.Settings
) -> models.Interaction:
interaction: models.Interaction = (
db.query(models.Interaction).filter(models.Interaction.id == id).first()
stmt = (
update(models.Interaction)
.where(models.Interaction.id == id)
.values(settings=settings)
)
result = await db.execute(stmt)

if interaction is not None:
interaction.settings = settings
db.commit()
return interaction
if result.rowcount:
await db.commit()
return True

return None


def get_messages(
db: Session, interaction_id: UUID = None, page: int = None, per_page: int = 10
async def get_messages(
db: AsyncSession, interaction_id: UUID = None, page: int = None, per_page: int = 10
) -> List[models.Message]:
query = db.query(models.Message)
stmt = select(models.Message)

if interaction_id is not None:
query = query.filter(models.Message.interaction_id == interaction_id)
stmt = stmt.where(models.Message.interaction_id == interaction_id)

if page is not None:
query = query.offset((page - 1) * per_page).limit(per_page)
stmt = stmt.offset((page - 1) * per_page).limit(per_page)

return query.all()
result = await db.execute(stmt)
return result.scalars().all()


def create_message(
db: Session, messages: List[schemas.MessageCreate], interaction_id: UUID
async def create_message(
db: AsyncSession, messages: List[schemas.MessageCreate], interaction_id: UUID
) -> List[models.Message]:
messages_db = []
for msg in messages:
Expand All @@ -84,5 +95,5 @@ def create_message(
db.add(message)
messages_db.append(message)

db.commit()
await db.commit()
return messages_db
15 changes: 11 additions & 4 deletions ifsguid/database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.orm import sessionmaker

from .config import settings

engine = create_engine(
settings.SQLALCHEMY_DATABASE_URI.unicode_string(), pool_pre_ping=True
engine: AsyncEngine = create_async_engine(
settings.SQLALCHEMY_DATABASE_URI.unicode_string()
)
async_session = sessionmaker(
autocommit=False,
autoflush=False,
future=True,
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True, bind=engine)
75 changes: 39 additions & 36 deletions ifsguid/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,42 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from g4f import ModelUtils

from . import crud, schemas, modules
from .database import SessionLocal
from .database import async_session, AsyncSession

router = APIRouter()


def get_db() -> SessionLocal:
try:
db = SessionLocal()
yield db
finally:
db.close()
async def get_db() -> AsyncSession:
async with async_session() as session:
yield session


@router.get("/", response_model=str)
async def get_root(db: Session = Depends(get_db)) -> str:
async def get_root(db: AsyncSession = Depends(get_db)) -> str:
return "Hello from IFSGuid!"


@router.get("/interactions", response_model=List[schemas.Interaction])
async def get_all_interactions(
page: Optional[int] = None,
per_page: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> List[schemas.Interaction]:
interactions = await crud.get_interactions(db=db, page=page, per_page=per_page)

return [
schemas.Interaction.model_validate(interaction)
for interaction in crud.get_interactions(db=db, page=page, per_page=per_page)
schemas.Interaction.model_validate(interaction) for interaction in interactions
]


@router.get(
"/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False
)
async def get_interactions(
id: UUID, db: Session = Depends(get_db)
id: UUID, db: AsyncSession = Depends(get_db)
) -> schemas.Interaction:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
Expand All @@ -48,15 +46,20 @@ async def get_interactions(

@router.post("/interactions", response_model=schemas.Interaction)
async def create_interactions(
settings: schemas.Settings, db: Session = Depends(get_db)
prompt: schemas.Prompt,
chat_model: schemas.ChatModel = Depends(),
db: AsyncSession = Depends(get_db),
) -> schemas.Interaction:
return schemas.Interaction.model_validate(
crud.create_interaction(db=db, settings=settings)
settings = schemas.Settings(
model=chat_model.model, prompt=prompt.prompt, role=prompt.role
)
interaction = await crud.create_interaction(db=db, settings=settings)

return schemas.Interaction.model_validate(interaction)


@router.delete("/interactions", response_model=Dict[str, Any], include_in_schema=False)
async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None:
async def delete_interaction(id: UUID, db: AsyncSession = Depends(get_db)) -> None:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
)
Expand All @@ -66,7 +69,7 @@ async def delete_interaction(id: UUID, db: Session = Depends(get_db)) -> None:
"/interactions/{id}", response_model=schemas.Interaction, include_in_schema=False
)
async def update_interaction(
id: UUID, settings: schemas.Settings, db: Session = Depends(get_db)
id: UUID, settings: schemas.Settings, db: AsyncSession = Depends(get_db)
) -> schemas.Interaction:
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="NotImplementedError"
Expand All @@ -80,30 +83,31 @@ async def get_all_message_in_interaction(
interaction_id: UUID,
page: Optional[int] = None,
per_page: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_db),
) -> List[schemas.Message]:
interaction = crud.get_interaction(db=db, id=str(interaction_id))
interaction = await crud.get_interaction(db=db, id=str(interaction_id))

if not interaction:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Interaction not found"
)

return [
schemas.Message.model_validate(message)
for message in crud.get_messages(
db=db, interaction_id=str(interaction_id), page=page, per_page=per_page
)
]
messages = await crud.get_messages(
db=db, interaction_id=str(interaction_id), page=page, per_page=per_page
)

return [schemas.Message.model_validate(message) for message in messages]


@router.post(
"/interactions/{interactions_id}/messages", response_model=List[schemas.Message]
)
async def create_message(
interaction_id: UUID, message: schemas.MessageCreate, db: Session = Depends(get_db)
interaction_id: UUID,
message: schemas.MessageCreate,
db: AsyncSession = Depends(get_db),
) -> schemas.Message:
interaction = crud.get_interaction(db=db, id=str(interaction_id))
interaction = await crud.get_interaction(db=db, id=str(interaction_id))

if not interaction:
raise HTTPException(
Expand All @@ -114,17 +118,16 @@ async def create_message(

messages = []
if message.role == "human":
ai_content = modules.generate_ai_response(
content=message.content, model=interaction.settings.model_name
ai_content = await modules.generate_ai_response(
content=message.content,
model=ModelUtils.convert[interaction.settings.model],
)
ai_message = schemas.MessageCreate(role="ai", content=ai_content)

messages.append(message)
messages.append(ai_message)

return [
schemas.Message.model_validate(message)
for message in crud.create_message(
db=db, messages=messages, interaction_id=str(interaction_id)
)
]
messages = await crud.create_message(
db=db, messages=messages, interaction_id=str(interaction_id)
)
return [schemas.Message.model_validate(message) for message in messages]
2 changes: 1 addition & 1 deletion ifsguid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class Interaction(Base):
)
settings = Column(JSON)

messages = relationship("Message", back_populates="interaction")
messages = relationship("Message", back_populates="interaction", lazy="selectin")


class Message(Base):
Expand Down
20 changes: 13 additions & 7 deletions ifsguid/modules.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import g4f
import traceback

g4f.debug.logging = True
g4f.check_version = False


def generate_ai_response(
content: str, model: str = "GPT4"
) -> str: # TODO: use async version for future
return g4f.ChatCompletion.create(
model=g4f.models.gpt_4 if model == "GPT4" else g4f.models.gpt_35_turbo,
messages=[{"role": "human", "content": content}],
)
async def generate_ai_response(
content: str, model: g4f.Model = g4f.models.default
) -> str:
try:
response = await g4f.ChatCompletion.create_async(
model=model,
messages=[{"role": "human", "content": content}],
)
return response
except Exception:
traceback.print_exc()
return "error!"
Loading

0 comments on commit 48bf621

Please sign in to comment.