Skip to content

Commit

Permalink
Update poetry env with 1bn personas, test nbs
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronWChen committed Oct 18, 2024
1 parent 1296b24 commit 33ac65e
Show file tree
Hide file tree
Showing 8 changed files with 31,359 additions and 57 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,5 @@ models/


/data

/secrets
57 changes: 42 additions & 15 deletions main_example.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
from fastapi import FastAPI, HTTPException, Path, Query
from schemas_example import GenreURLChoices, BandBase, BandCreate, BandWithID
from fastapi import FastAPI, HTTPException, Path, Query, Depends
from models_example import GenreURLChoices, BandBase, BandCreate, Band, Album
from sqlmodel import Session, select
from typing import Annotated
from contextlib import asynccontextmanager
from db_example import init_db, get_session


@asynccontextmanager
async def lifespan(app: FastAPI):
init_db()
yield


# set --port argument, can't use 8000, the default uvicorn
# use localhost:{port} in browser
# use localhost:{port}/docs to look at the interactive, automatically created documentation

app = FastAPI()
app = FastAPI(lifespan=lifespan)


BANDS = [
Expand All @@ -28,8 +38,9 @@ async def bands(
genre: GenreURLChoices | None = None,
# has_albums: bool = False,
name_query: Annotated[str | None, Query(max_length=10)] = None,
) -> list[BandWithID]:
band_list = [BandWithID(**b) for b in BANDS]
session: Session = Depends(get_session),
) -> list[Band]:
band_list = session.exec(select(Band)).all()

if genre:
band_list = [b for b in band_list if b.genre.value.lower() == genre.value]
Expand All @@ -44,8 +55,11 @@ async def bands(


@app.get("/bands/{band_id}")
async def band(band_id: Annotated[int, Path(title="The band ID")]) -> BandWithID:
band = next((BandWithID(**b) for b in BANDS if b["id"] == band_id), None)
async def band(
band_id: Annotated[int, Path(title="The band ID")],
session: Session = Depends(get_session),
) -> Band:
band = session.get(Band, band_id)
# Aaron: I'm a little confused, could we use `get` instead?

if band is None:
Expand All @@ -54,15 +68,28 @@ async def band(band_id: Annotated[int, Path(title="The band ID")]) -> BandWithID
return band


@app.get("/bands/genre/{genre}")
async def bands_for_genre(genre: GenreURLChoices) -> list[dict]:
# originally this allowed any string to be used as an input and a list comprehension was used to find the value. However, this could result in server/computer waste, so we refactored to use a custom class that was restricted to a known set of genres
return [b for b in BANDS if b["genre"].lower() == genre.value]
# @app.get("/bands/genre/{genre}")
# async def bands_for_genre(genre: GenreURLChoices) -> list[dict]:
# # originally this allowed any string to be used as an input and a list comprehension was used to find the value. However, this could result in server/computer waste, so we refactored to use a custom class that was restricted to a known set of genres
# return [b for b in BANDS if b["genre"].lower() == genre.value]


@app.post("/bands")
async def create_band(band_data: BandCreate) -> BandWithID:
id = BANDS[-1]["id"] + 1
band = BandWithID(id=id, **band_data.model_dump()).model_dump()
BANDS.append(band)
async def create_band(
band_data: BandCreate, session: Session = Depends(get_session)
) -> Band:

band = Band(name=band_data.name, genre=band_data.genre)
session.add(band)

if band_data.albums:
for album in band_data.albums:
album_obj = Album(
title=album.title, release_date=album.release_date, band=band
)
session.add(album_obj)

session.commit()
session.refresh(band)

return band
51 changes: 51 additions & 0 deletions models_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# BugBytes instructor likes to use the schemas.py file and add all the Pydantic schemas to that file
from datetime import date
from enum import Enum
from pydantic import BaseModel, field_validator
from sqlmodel import SQLModel, Field, Relationship


class GenreURLChoices(Enum):
ROCK = "rock"
ELECTRONIC = "electronic"
METAL = "metal"
HIP_HOP = "hip-hop"
SHOEGAZE = "shoegaze"


class GenreChoices(Enum):
ROCK = "Rock"
ELECTRONIC = "Electronic"
METAL = "Metal"
HIP_HOP = "Hip-Hop"
SHOEGAZE = "Shoegaze"


class AlbumBase(SQLModel):
title: str
release_date: date
band_id: int | None = Field(default=None, foreign_key="band.id")


class Album(AlbumBase, table=True):
id: int = Field(default=None, primary_key=True)
band: "Band" = Relationship(back_populates="albums")


class BandBase(SQLModel):
name: str
genre: GenreChoices


class BandCreate(BandBase):
albums: list[AlbumBase] | None = None

# only pass because it is strictly inheriting and not adding other fields
@field_validator("genre", mode="before")
def title_case_genre(cls, value):
return value.title()


class Band(BandBase, table=True):
id: int = Field(default=None, primary_key=True)
albums: list[Album] = Relationship(back_populates="band")
Loading

0 comments on commit 33ac65e

Please sign in to comment.