diff --git a/app/main.py b/app/main.py index 5c71d8d..b9fb90e 100644 --- a/app/main.py +++ b/app/main.py @@ -4,15 +4,10 @@ from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles -from . import templates from .env import DEBUG from .routers import order, placements, products, stat -from .store import ( - PlacedItemTable, - PlacementTable, - ProductTable, - startup_and_shutdown_db, -) +from .store import startup_and_shutdown_db +from .templates import macro_template # https://stackoverflow.com/a/65270864 @@ -29,24 +24,17 @@ async def lifespan(_: FastAPI): app.mount("/static", StaticFiles(directory="static"), name="static") -app.include_router(products.router) -app.include_router(order.router) -app.include_router(placements.router) -app.include_router(stat.router) + +@macro_template("index.html") +def tmp_index(): ... @app.get("/", response_class=HTMLResponse) async def get_root(request: Request): - return HTMLResponse(templates.index(request)) + return HTMLResponse(tmp_index(request)) -if DEBUG: - - @app.get("/test") - async def test(): - return { - "product_table": await ProductTable.select_all(), - "order_sessions": order.order_sessions, - "placed_item_table": await PlacedItemTable.select_all(), - "placement_table": await PlacementTable.select_all(), - } +app.include_router(products.router) +app.include_router(order.router) +app.include_router(placements.router) +app.include_router(stat.router) diff --git a/app/routers/order.py b/app/routers/order.py index 517bdb1..feeca59 100644 --- a/app/routers/order.py +++ b/app/routers/order.py @@ -12,13 +12,76 @@ status, ) from fastapi.responses import HTMLResponse +import pydantic -from .. import templates -from ..store import PlacedItemTable, PlacementTable, ProductTable -from ..store.product import OrderSession +from ..store import PlacedItemTable, PlacementTable, Product, ProductTable +from ..templates import hx_post as tmp_hx_post +from ..templates import macro_template router = APIRouter() + +class OrderSession(pydantic.BaseModel): + class CountedProduct(pydantic.BaseModel): + name: str + price: str + count: int = pydantic.Field(default=1) + + items: dict[UUID, Product] = pydantic.Field(default_factory=dict) + counted_products: dict[int, CountedProduct] = pydantic.Field(default_factory=dict) + total_count: int = pydantic.Field(default=0) + total_price: int = pydantic.Field(default=0) + + def clear(self): + self.total_count = 0 + self.total_price = 0 + self.items = {} + self.counted_products = {} + + def total_price_str(self) -> str: + return Product.to_price_str(self.total_price) + + def add(self, p: Product): + self.total_count += 1 + self.total_price += p.price + self.items[uuid4()] = p + if p.product_id in self.counted_products: + self.counted_products[p.product_id].count += 1 + else: + counted_product = self.CountedProduct(name=p.name, price=p.price_str()) + self.counted_products[p.product_id] = counted_product + + def delete(self, item_id: UUID): + if item_id in self.items: + self.total_count -= 1 + product = self.items.pop(item_id) + self.total_price -= product.price + if self.counted_products[product.product_id].count == 1: + self.counted_products.pop(product.product_id) + else: + self.counted_products[product.product_id].count -= 1 + + +@macro_template("order.html") +def tmp_order(products: list[Product], session: OrderSession): ... + + +@macro_template("order.html", "order_session") +def tmp_session(session: OrderSession): ... + + +@macro_template("order.html", "confirm_modal") +def tmp_confirm_modal(session: OrderSession): ... + + +@macro_template("order.html", "issued_modal") +def tmp_issued_modal(placement_id: int, session: OrderSession): ... + + +@macro_template("order.html", "error_modal") +def tmp_error_modal(message: str): ... + + # NOTE: Do NOT store this data in database (the data is transient and should be kept in memory) order_sessions: dict[UUID, OrderSession] = {} SESSION_COOKIE_KEY = "session_key" @@ -39,24 +102,22 @@ async def instruct_creation_of_new_session_or_get_existing_session( ): if session_key is None or (session := order_sessions.get(session_key)) is None: return HTMLResponse( - templates.hx_post(request, "/order"), + tmp_hx_post(request, "/order"), status_code=status.HTTP_405_METHOD_NOT_ALLOWED, headers={"allow": "POST"}, ) products = await ProductTable.select_all() - return HTMLResponse(templates.order.page(request, products, session)) + return HTMLResponse(tmp_order(request, products, session)) -@router.get("/order/confirm", response_class=HTMLResponse) +@router.get("/order/confirm-modal", response_class=HTMLResponse) async def get_confirm_dialog(request: Request, session: SessionDeps): if session.total_count == 0: - error_status = "エラー:商品が選択されていません" + error_msg = "商品が選択されていません" + return HTMLResponse(tmp_error_modal(request, error_msg)) else: - error_status = None - return HTMLResponse( - templates.components.order_confirm(request, session, error_status) - ) + return HTMLResponse(tmp_confirm_modal(request, session)) @router.post("/order") @@ -73,10 +134,8 @@ async def create_new_session_or_place_order( return res if session.total_count == 0: - error_status = "エラー:商品が選択されていません" - return HTMLResponse( - templates.components.order_issued(request, None, session, error_status) - ) + error_msg = "商品が選択されていません" + return HTMLResponse(tmp_error_modal(request, error_msg)) order_sessions.pop(session_key) res = await _place_order(request, session) @@ -86,8 +145,7 @@ async def create_new_session_or_place_order( def _create_new_session() -> UUID: session_key = uuid4() - session = OrderSession() - order_sessions[session_key] = session + order_sessions[session_key] = OrderSession() return session_key @@ -96,9 +154,7 @@ async def _place_order(request: Request, session: SessionDeps) -> HTMLResponse: placement_id = await PlacedItemTable.issue(product_ids) # TODO: add a branch for out of stock error await PlacementTable.insert(placement_id) - return HTMLResponse( - templates.components.order_issued(request, placement_id, session, None) - ) + return HTMLResponse(tmp_issued_modal(request, placement_id, session)) @router.post("/order/items") @@ -109,19 +165,19 @@ async def add_order_item( raise HTTPException(status_code=404, detail=f"Product {product_id} not found") session.add(product) - return HTMLResponse(templates.order.session(request, session)) + return HTMLResponse(tmp_session(request, session)) @router.delete("/order/items/{item_id}", response_class=HTMLResponse) async def delete_order_item(request: Request, session: SessionDeps, item_id: UUID): session.delete(item_id) - return HTMLResponse(templates.order.session(request, session)) + return HTMLResponse(tmp_session(request, session)) @router.delete("/order/items") async def clear_order_items(request: Request, session: SessionDeps) -> Response: session.clear() - return HTMLResponse(templates.order.session(request, session)) + return HTMLResponse(tmp_session(request, session)) # TODO: add proper path operation for order deferral @@ -141,7 +197,7 @@ async def clear_order_items(request: Request, session: SessionDeps) -> Response: # # TODO: respond with a message about the success of the deferral action # # message = "注文を保留しました" # # res = HTMLResponse( -# # templates.order.session(request, OrderSession(), message=message) +# # tmp_session(request, OrderSession(), message=message) # # ) # # res.delete_cookie(SESSION_COOKIE_KEY) # # return res diff --git a/app/routers/placements.py b/app/routers/placements.py index 66a6db4..3e85531 100644 --- a/app/routers/placements.py +++ b/app/routers/placements.py @@ -1,29 +1,346 @@ import asyncio -from typing import Annotated, AsyncGenerator, Literal +from datetime import datetime +from functools import partial +from typing import Annotated, Any, AsyncGenerator, Awaitable, Callable, Literal, Mapping +import sqlalchemy +import sqlmodel from fastapi import APIRouter, Form, Header, HTTPException, Request, status from fastapi.responses import HTMLResponse +from sqlmodel import col from sse_starlette.sse import EventSourceResponse -from .. import templates from ..store import ( + PlacedItem, + Placement, PlacementTable, + Product, database, - load_incoming_placements, - load_one_resolved_placement, - load_placed_items_incoming, - load_resolved_placements, supply_all_and_complete, supply_and_complete_placement_if_done, + unixepoch, ) +from ..store.placement import ModifiedFlag +from ..templates import macro_template router = APIRouter() +def _to_time(unix_epoch: int) -> str: + return datetime.fromtimestamp(unix_epoch).strftime("%H:%M:%S") + + +async def _agen_query_executor[T]( + query: str, + unique_key: Literal["placement_id"] | Literal["product_id"], + init_cb: Callable[[Any, Mapping], None], + elem_cb: Callable[[Mapping], T], + list_cb: Callable[[list[T]], None], +): + prev_unique_id = -1 + lst: list[T] = list() + async for map in database.iterate(query): + if (unique_id := map[unique_key]) != prev_unique_id: + if prev_unique_id != -1: + list_cb(lst) + prev_unique_id = unique_id + init_cb(unique_id, map) + lst: list[T] = list() + lst.append(elem_cb(map)) + if prev_unique_id != -1: + list_cb(lst) + + +query_placed_items_incoming: sqlalchemy.Select = ( + sqlmodel.select(PlacedItem.placement_id, PlacedItem.product_id) + .add_columns(sqlmodel.func.count(col(PlacedItem.product_id)).label("count")) + .where(col(PlacedItem.supplied_at).is_(None)) # Filter out supplied items + .group_by(col(PlacedItem.placement_id), col(PlacedItem.product_id)) + .select_from(sqlmodel.join(PlacedItem, Product)) + .add_columns(col(Product.name), col(Product.filename)) + .join(Placement) + .add_columns(unixepoch(col(Placement.placed_at))) + .where(col(Placement.canceled_at).is_(None) & col(Placement.completed_at).is_(None)) + .order_by(col(PlacedItem.product_id).asc(), col(PlacedItem.placement_id).asc()) +) + +type placed_item_t = dict[str, int | str | list[dict[str, int | str]]] + + +def _placed_items_loader() -> Callable[[], Awaitable[list[placed_item_t]]]: + query_str = str(query_placed_items_incoming.compile()) + + placed_items: list[placed_item_t] = [] + + def init_cb(product_id: int, map: Mapping): + placed_items.append( + {"product_id": product_id, "name": map["name"], "filename": map["filename"]} + ) + + def elem_cb(map: Mapping) -> dict[str, int | str]: + return { + "placement_id": map["placement_id"], + "count": map["count"], + "placed_at": _to_time(map["placed_at"]), + } + + def list_cb(placements: list[dict[str, int | str]]): + placed_items[-1]["placements"] = placements + + load_placed_products = partial( + _agen_query_executor, query_str, "product_id", init_cb, elem_cb, list_cb + ) + + async def load(): + placed_items.clear() + await load_placed_products() + return placed_items + + return load + + +load_placed_items_incoming = _placed_items_loader() + + +class placed_items_incoming: # namespace + @macro_template("placed-items-incoming.html") + @staticmethod + def page(placed_items: list[placed_item_t]): ... + + @macro_template("placed-items-incoming.html", "component") + @staticmethod + def component(placed_items: list[placed_item_t]): ... + + @macro_template("placed-items-incoming.html", "component_with_sound") + @staticmethod + def component_with_sound(placed_items: list[placed_item_t]): ... + + +type item_t = dict[str, int | str | None] +type placement_t = dict[str, int | list[item_t] | str | datetime | None] + + +query_incoming: sqlalchemy.Select = ( + # Query from the placements table + sqlmodel.select(Placement.placement_id) + .group_by(col(Placement.placement_id)) + .order_by(col(Placement.placement_id).asc()) + .add_columns(unixepoch(col(Placement.placed_at))) + # Filter out canceled/completed placements + .where(col(Placement.canceled_at).is_(None) & col(Placement.completed_at).is_(None)) + # Query the list of placed items + .select_from(sqlmodel.join(Placement, PlacedItem)) + .add_columns(col(PlacedItem.product_id), unixepoch(col(PlacedItem.supplied_at))) + .group_by(col(PlacedItem.product_id)) + .order_by(col(PlacedItem.product_id).asc()) + .add_columns(sqlmodel.func.count(col(PlacedItem.product_id)).label("count")) + # Query product name + .join(Product) + .add_columns(col(Product.name)) +) + + +query_resolved: sqlalchemy.Select = ( + # Query from the placements table + sqlmodel.select(Placement.placement_id) + .group_by(col(Placement.placement_id)) + .order_by(col(Placement.placement_id).asc()) + .add_columns(unixepoch(col(Placement.placed_at))) + # Query canceled/completed placements + .where( + col(Placement.canceled_at).isnot(None) | col(Placement.completed_at).isnot(None) + ) + .add_columns(unixepoch(col(Placement.canceled_at))) + .add_columns(unixepoch(col(Placement.completed_at))) + # Query the list of placed items + .select_from(sqlmodel.join(Placement, PlacedItem)) + .add_columns(col(PlacedItem.product_id), unixepoch(col(PlacedItem.supplied_at))) + .group_by(col(PlacedItem.product_id)) + .order_by(col(PlacedItem.product_id).asc()) + .add_columns(sqlmodel.func.count(col(PlacedItem.product_id)).label("count")) + # Query product name and price + .join(Product) + .add_columns(col(Product.name), col(Product.price)) +) + + +def callbacks_placements_incoming( + placements: list[placement_t], +) -> tuple[ + Callable[[int, Mapping], None], + Callable[[Mapping], item_t], + Callable[[list[item_t]], None], +]: + def init_cb(placement_id: int, map: Mapping) -> None: + placements.append( + { + "placement_id": placement_id, + "placed_at": _to_time(map["placed_at"]), + } + ) + + def elem_cb(map: Mapping) -> item_t: + supplied_at = map["supplied_at"] + return { + "product_id": map["product_id"], + "count": map["count"], + "name": map["name"], + "supplied_at": _to_time(supplied_at) if supplied_at else None, + } + + def list_cb(items: list[item_t]) -> None: + placements[-1]["items_"] = items + + return init_cb, elem_cb, list_cb + + +def callbacks_placements_resolved( + placements: list[placement_t], +) -> tuple[ + Callable[[int, Mapping], None], + Callable[[Mapping], item_t], + Callable[[list[item_t]], None], +]: + total_price = 0 + + def init_cb(placement_id: int, map: Mapping) -> None: + canceled_at, completed_at = map["canceled_at"], map["completed_at"] + placements.append( + { + "placement_id": placement_id, + "placed_at": _to_time(map["placed_at"]), + "canceled_at": _to_time(canceled_at) if canceled_at else None, + "completed_at": _to_time(completed_at) if completed_at else None, + } + ) + nonlocal total_price + total_price = 0 + + def elem_cb(map: Mapping) -> item_t: + count, price = map["count"], map["price"] + nonlocal total_price + total_price += count * price + supplied_at = map["supplied_at"] + return { + "product_id": map["product_id"], + "count": count, + "name": map["name"], + "price": Product.to_price_str(price), + "supplied_at": _to_time(supplied_at) if supplied_at else None, + } + + def list_cb(items: list[item_t]) -> None: + placements[-1]["items_"] = items + placements[-1]["total_price"] = Product.to_price_str(total_price) + + return init_cb, elem_cb, list_cb + + +def _placements_loader( + query: sqlalchemy.Compiled, + callbacks: Callable[ + [list[placement_t]], + tuple[ + Callable[[int, Mapping], None], + Callable[[Mapping], item_t], + Callable[[list[item_t]], None], + ], + ], +) -> Callable[[], Awaitable[list[placement_t]]]: + placements: list[placement_t] = [] + + init_cb, elem_cb, list_cb = callbacks(placements) + load_placements = partial( + _agen_query_executor, str(query), "placement_id", init_cb, elem_cb, list_cb + ) + + async def load(): + placements.clear() + await load_placements() + return placements + + return load + + +load_incoming_placements = _placements_loader( + query_incoming.compile(), callbacks_placements_incoming +) +load_resolved_placements = _placements_loader( + query_resolved.compile(), callbacks_placements_resolved +) + + +async def load_one_resolved_placement(placement_id: int) -> placement_t | None: + query = query_resolved.where(col(Placement.placement_id) == placement_id) + + rows_agen = database.iterate(query) + if (row := await anext(rows_agen, None)) is None: + return None + + canceled_at, completed_at = row["canceled_at"], row["completed_at"] + placement: placement_t = { + "placement_id": placement_id, + "placed_at": _to_time(row["placed_at"]), + "canceled_at": _to_time(canceled_at) if canceled_at else None, + "completed_at": _to_time(completed_at) if completed_at else None, + } + + total_price = 0 + + def to_item(row: Mapping) -> item_t: + count, price = row["count"], row["price"] + nonlocal total_price + total_price += count * price + supplied_at = row["supplied_at"] + return { + "product_id": row["product_id"], + "count": count, + "name": row["name"], + "price": Product.to_price_str(price), + "supplied_at": _to_time(supplied_at) if supplied_at else None, + } + + items = [to_item(row)] + async for row in rows_agen: + items.append(to_item(row)) + placement["items_"] = items + placement["total_price"] = Product.to_price_str(total_price) + + return placement + + +class incoming_placements: # namespace + @macro_template("incoming-placements.html") + @staticmethod + def page(placements: list[placement_t]): ... + + @macro_template("incoming-placements.html", "component") + @staticmethod + def component(placements: list[placement_t]): ... + + @macro_template("incoming-placements.html", "component_with_sound") + @staticmethod + def component_with_sound(placements: list[placement_t]): ... + + +class resolved_placements: # namespace + @macro_template("resolved-placements.html") + @staticmethod + def page(placements: list[placement_t]): ... + + @macro_template("resolved-placements.html", "completed") + @staticmethod + def completed(placement: placement_t): ... + + @macro_template("resolved-placements.html", "canceled") + @staticmethod + def canceled(placement: placement_t): ... + + @router.get("/placed-items/incoming", response_class=HTMLResponse) async def get_incoming_placed_items(request: Request): placed_items = await load_placed_items_incoming() - return HTMLResponse(templates.placed_items_incoming.page(request, placed_items)) + return HTMLResponse(placed_items_incoming.page(request, placed_items)) @router.get("/placed-items/incoming-stream", response_class=EventSourceResponse) @@ -36,17 +353,18 @@ async def placed_items_incoming_stream( async def _placed_items_incoming_stream(request: Request): placed_items = await load_placed_items_incoming() - content = templates.placed_items_incoming.component(request, placed_items) + content = placed_items_incoming.component(request, placed_items) yield dict(data=content) try: while True: - async with PlacementTable.modified: - await PlacementTable.modified.wait() + async with PlacementTable.modified_cond_flag: + flag = await PlacementTable.modified_cond_flag.wait() + if flag & (ModifiedFlag.INCOMING | ModifiedFlag.PUT_BACK): + template = placed_items_incoming.component_with_sound + else: + template = placed_items_incoming.component placed_items = await load_placed_items_incoming() - content = templates.placed_items_incoming.component( - request, placed_items - ) - yield dict(data=content) + yield dict(data=template(request, placed_items)) except asyncio.CancelledError: yield dict(event="shutdown", data="") finally: @@ -61,7 +379,7 @@ async def supply_products(placement_id: int, product_id: int): @router.get("/placements/incoming", response_class=HTMLResponse) async def get_incoming_placements(request: Request): placements = await load_incoming_placements() - return HTMLResponse(templates.incoming_placements.page(request, placements)) + return HTMLResponse(incoming_placements.page(request, placements)) @router.get("/placements/incoming-stream", response_class=EventSourceResponse) @@ -76,15 +394,18 @@ async def _incoming_placements_stream( request: Request, ) -> AsyncGenerator[dict[str, str], None]: placements = await load_incoming_placements() - content = templates.incoming_placements.component(request, placements) + content = incoming_placements.component(request, placements) yield dict(data=content) try: while True: - async with PlacementTable.modified: - await PlacementTable.modified.wait() + async with PlacementTable.modified_cond_flag: + flag = await PlacementTable.modified_cond_flag.wait() + if flag & (ModifiedFlag.INCOMING | ModifiedFlag.PUT_BACK): + template = incoming_placements.component_with_sound + else: + template = incoming_placements.component placements = await load_incoming_placements() - content = templates.incoming_placements.component(request, placements) - yield dict(data=content) + yield dict(data=template(request, placements)) except asyncio.CancelledError: yield dict(event="shutdown", data="") finally: @@ -94,7 +415,7 @@ async def _incoming_placements_stream( @router.get("/placements/resolved", response_class=HTMLResponse) async def get_resolved_placements(request: Request): placements = await load_resolved_placements() - return HTMLResponse(templates.resolved_placements.page(request, placements)) + return HTMLResponse(resolved_placements.page(request, placements)) @router.delete("/placements/{placement_id}/resolved-at") @@ -118,7 +439,7 @@ async def complete( detail = f"Placement {placement_id} not found" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) - return HTMLResponse(templates.resolved_placements.completed(request, placement)) + return HTMLResponse(resolved_placements.completed(request, placement)) @router.post("/placements/{placement_id}/canceled-at", response_class=HTMLResponse) @@ -137,4 +458,4 @@ async def cancel( detail = f"Placement {placement_id} not found" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) - return HTMLResponse(templates.resolved_placements.canceled(request, placement)) + return HTMLResponse(resolved_placements.canceled(request, placement)) diff --git a/app/routers/products.py b/app/routers/products.py index a81639b..3062d50 100644 --- a/app/routers/products.py +++ b/app/routers/products.py @@ -3,16 +3,28 @@ from fastapi import APIRouter, Form, HTTPException, Request, Response, status from fastapi.responses import HTMLResponse -from .. import templates from ..store import Product, ProductTable, delete_product +from ..templates import macro_template router = APIRouter() +@macro_template("products.html") +def tmp_products(products: list[Product]): ... + + +@macro_template("products.html", "editor") +def tmp_editor(product: Product): ... + + +@macro_template("products.html", "empty_editor") +def tmp_empty_editor(): ... + + @router.get("/products", response_class=HTMLResponse) async def get_products(request: Request): products = await ProductTable.select_all() - return HTMLResponse(templates.products.page(request, products)) + return HTMLResponse(tmp_products(request, products)) @router.post("/products", response_class=Response) @@ -88,12 +100,12 @@ async def get_product_editor(request: Request, product_id: int): if (product := maybe_product) is None: detail = f"Product {product_id} not found" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) - return HTMLResponse(templates.products.editor(request, product)) + return HTMLResponse(tmp_editor(request, product)) @router.get("/product-editor", response_class=HTMLResponse) async def get_empty_product_editor(request: Request): - return HTMLResponse(templates.products.empty_editor(request)) + return HTMLResponse(tmp_empty_editor(request)) # TODO: This path is defined temporally for convenience and should be removed in the future. diff --git a/app/routers/stat.py b/app/routers/stat.py index c0083e9..da3f921 100644 --- a/app/routers/stat.py +++ b/app/routers/stat.py @@ -1,43 +1,83 @@ -from fastapi import APIRouter, Request, Response +import csv +from dataclasses import dataclass +from datetime import datetime +from functools import lru_cache +from pathlib import Path +from typing import Annotated, Literal, Mapping + +import sqlalchemy +import sqlmodel +from fastapi import APIRouter, Header, Request from fastapi.responses import HTMLResponse +from sqlmodel import col -from .. import templates -from ..store import ( - PlacedItemTable, - ProductTable, - PlacementTable, - database, -) +from ..templates import macro_template +from ..store import PlacedItem, Placement, Product, database, unixepoch -import csv -import os +router = APIRouter() -from datetime import datetime, timedelta -from typing import Any, Mapping -import statistics +CSV_OUTPUT_PATH = Path("./static/stat.csv") +GRAPH_OUTPUT_PATH = Path("./static/sales.png") -router = APIRouter() -DATABASE_URL = os.path.abspath("./db/app.db") -CSV_OUTPUT_PATH = os.path.abspath("./static/stat.csv") -GRAPH_OUTPUT_PATH = os.path.abspath("./static/sales.png") +@dataclass +class Stat: + @dataclass + class SalesSummary: + product_id: int + name: str + filename: str + price: str + count: int + count_today: int + total_sales: str + total_sales_today: str + no_stock: int | None + + total_sales_all_time: str + total_sales_today: str + total_items_all_time: int + total_items_today: int + sales_summary_list: list[SalesSummary] + avg_service_time_all: str + avg_service_time_recent: str + +@macro_template("stat.html") +def tmp_stat(stat: Stat): ... -def convert_unixepoch_to_localtime(unixepoch_time): + +@macro_template("wait-estimate.html") +def tmp_wait_estimate_page(estimate: str, waiting_order_count: int): ... + + +@macro_template("wait-estimate.html", "component") +def tmp_wait_estimate_component(estimate: str, waiting_order_count: int): ... + + +def convert_unixepoch_to_localtime(unixepoch_time: int) -> str: local_time = datetime.fromtimestamp(unixepoch_time).astimezone() return local_time.strftime("%Y-%m-%d %H:%M:%S") +def zero_if_null[T](v: T | None) -> T | Literal[0]: + """ + Handles the case where aggregate functions return NULL when no matching rows + are found + """ + return v if v is not None else 0 + + +# TODO: Use async operations for writing csv rows so that this function does not block async def export_placements(): query = """ SELECT - placements.placement_id, + placements.placement_id, + placed_items.item_no, unixepoch(placements.placed_at) AS placed_at, - unixepoch(placements.completed_at) AS completed_at, - placements.canceled_at, - placed_items.product_id, - products.product_id, - products.name, + unixepoch(placements.completed_at) AS completed_at, + placed_items.product_id, + products.name, products.price FROM placements @@ -45,25 +85,23 @@ async def export_placements(): placed_items ON placements.placement_id = placed_items.placement_id INNER JOIN products ON placed_items.product_id = products.product_id - WHERE - placements.canceled_at IS NULL; + WHERE + placements.canceled_at IS NULL + ORDER BY + placements.placement_id ASC; """ - csv_file_path = os.path.abspath(CSV_OUTPUT_PATH) - with open(csv_file_path, "w", newline="") as csv_file: + with open(CSV_OUTPUT_PATH, "w", newline="") as csv_file: csv_writer = csv.writer(csv_file) async_gen = database.iterate(query) if (row := await anext(async_gen, None)) is None: return - headers = [ - key for key in dict(row).keys() if key not in ("product_id", "canceled_at") - ] + headers = [key for key in dict(row).keys()] csv_writer.writerow(headers) csv_writer.writerow(_filtered_row(row)) - async for row in async_gen: csv_writer.writerow(_filtered_row(row)) @@ -73,175 +111,163 @@ def _filtered_row(row: Mapping) -> list: for column_name, value in dict(row).items(): if column_name in ("placed_at", "completed_at") and value is not None: value = convert_unixepoch_to_localtime(value) - if column_name not in ("product_id", "canceled_at"): - filtered_row.append(value) + filtered_row.append(value) return filtered_row -async def compute_total_sales() -> tuple[int, int, int, int, list[dict[str, Any]]]: - product_table = await ProductTable.select_all() - placed_item_table = await PlacedItemTable.select_all() - placement_table = await PlacementTable.select_all() - - product_price_map = {product.product_id: product for product in product_table} - total_sales_all_time = 0 - total_sales_today = 0 - total_items_all_time = 0 - total_items_today = 0 - sales_summary_aggregated = {} +_placed_today = sqlmodel.func.date( + col(Placement.placed_at), "localtime" +) == sqlmodel.func.date("now", "localtime") +TOTAL_SALES_QUERY: sqlalchemy.Compiled = ( + sqlmodel.select(col(Product.product_id)) + .select_from(sqlmodel.join(PlacedItem, Placement)) + .join(Product) + .add_columns( + sqlmodel.func.count(col(Product.product_id)).label("count"), + sqlmodel.func.count(col(Product.product_id)) + .filter(_placed_today) + .label("count_today"), + col(Product.name), + col(Product.filename), + col(Product.price), + sqlmodel.func.sum(col(Product.price)).label("total_sales"), + sqlmodel.func.sum(col(Product.price)) + .filter(_placed_today) + .label("total_sales_today"), + col(Product.no_stock), + ) + .where(col(Placement.canceled_at).is_(None)) + .group_by(col(Product.product_id)) + .compile(compile_kwargs={"literal_binds": True}) +) - today = datetime.today().date() - for item in placed_item_table: - product_id = item.product_id - placement = next( - (p for p in placement_table if p.placement_id == item.placement_id), None +class AvgServiceTimeQuery: + @classmethod + @lru_cache(1) + def all_and_recent(cls) -> sqlalchemy.Compiled: + return ( + sqlmodel.select( + sqlmodel.func.avg(cls._service_time_diff).label("all"), + sqlmodel.func.avg(cls._last_30mins).label("recent"), + ) + .where(col(Placement.completed_at).isnot(None)) + .compile() ) - if ( - placement - and placement.canceled_at is None - and product_id in product_price_map - ): - product_info = product_price_map[product_id] - - if product_info.name not in sales_summary_aggregated: - sales_summary_aggregated[product_info.name] = { - "name": product_info.name, - "filename": product_info.filename, - "count": 1, - "total_sales": product_info.price, - "no_stock": product_info.no_stock, - } - else: - sales_summary_aggregated[product_info.name]["count"] += 1 - sales_summary_aggregated[product_info.name]["total_sales"] += ( - product_info.price - ) - - total_sales_all_time += product_info.price - total_items_all_time += 1 - - if placement.completed_at is not None: - placed_date = ( - datetime.fromisoformat(str(placement.placed_at)) - + (datetime.now().astimezone().utcoffset() or timedelta(0)) - ).date() - if placed_date == today: - total_sales_today += product_info.price - total_items_today += 1 - - sales_summary_list = list(sales_summary_aggregated.values()) + @classmethod + @lru_cache(1) + def recent(cls) -> sqlalchemy.Compiled: + return ( + sqlmodel.select(sqlmodel.func.avg(cls._last_30mins).label("recent")) + .where(col(Placement.completed_at).isnot(None)) + .compile() + ) - return ( - total_sales_all_time, - total_sales_today, - total_items_all_time, - total_items_today, - sales_summary_list, + _service_time_diff = unixepoch(col(Placement.completed_at)) - unixepoch( + col(Placement.placed_at) + ) + _elapsed_secs = sqlmodel.func.unixepoch() - unixepoch(col(Placement.completed_at)) + _last_30mins = sqlmodel.case( + (_elapsed_secs / sqlmodel.text("60") < sqlmodel.text("30"), _service_time_diff) ) + @staticmethod + def seconds_to_jpn_mmss(secs: int) -> str: + mm, ss = divmod(secs, 60) + return f"{mm} 分 {ss} 秒" -async def compute_average_service_time() -> tuple[str, str]: - placement_table = await PlacementTable.select_all() - - all_service_times = [] - recent_service_times = [] - now = datetime.now().astimezone() - offset = now.utcoffset() or timedelta(0) - thirty_minutes_ago = now - timedelta(minutes=30) - offset - - for placement in placement_table: - if placement.completed_at is not None: - placed_at = datetime.fromisoformat(str(placement.placed_at)).astimezone() - completed_at = datetime.fromisoformat( - str(placement.completed_at) - ).astimezone() - - time_diff = (completed_at - placed_at).total_seconds() - all_service_times.append(time_diff) - if completed_at >= thirty_minutes_ago: - recent_service_times.append(time_diff) +async def construct_stat() -> Stat: + sales_summary_aggregated: dict[int, Stat.SalesSummary] = {} + total_sales_all_time = 0 + total_sales_today = 0 + total_items_all_time = 0 + total_items_today = 0 - if all_service_times: - average_service_time_all_seconds = statistics.mean(all_service_times) - average_all_minutes, average_all_seconds = divmod( - int(average_service_time_all_seconds), 60 + async for row in database.iterate(str(TOTAL_SALES_QUERY)): + product_id = row["product_id"] + assert isinstance(product_id, int) + + count, count_today, total_sales, total_sales_today_ = map( + zero_if_null, + ( + row["count"], + row["count_today"], + row["total_sales"], + row["total_sales_today"], + ), ) - average_service_time_all = f"{average_all_minutes} 分 {average_all_seconds} 秒" - else: - average_service_time_all = "0 分 0 秒" - if recent_service_times: - average_service_time_recent_seconds = statistics.mean(recent_service_times) - average_recent_minutes, average_recent_seconds = divmod( - int(average_service_time_recent_seconds), 60 + sales_summary_aggregated[product_id] = Stat.SalesSummary( + product_id=product_id, + name=row["name"], + filename=row["filename"], + price=Product.to_price_str(row["price"]), + count=count, + count_today=count_today, + total_sales=Product.to_price_str(total_sales), + total_sales_today=Product.to_price_str(total_sales_today_), + no_stock=row["no_stock"], ) - average_service_time_recent = ( - f"{average_recent_minutes} 分 {average_recent_seconds} 秒" - ) - else: - average_service_time_recent = "0 分 0 秒" - return average_service_time_all, average_service_time_recent + total_sales_all_time += total_sales + total_sales_today += total_sales_today_ + total_items_all_time += count + total_items_today += count_today -async def compute_waiting_orders() -> int: - placement_table = await PlacementTable.select_all() - waiting_orders = 0 - for placement in placement_table: - if placement.completed_at is None and placement.canceled_at is None: - waiting_orders += 1 - return waiting_orders + sales_summary_list = list(sales_summary_aggregated.values()) + + record = await database.fetch_one(str(AvgServiceTimeQuery.all_and_recent())) + assert record is not None + avg_service_time_all, avg_service_time_recent = ( + AvgServiceTimeQuery.seconds_to_jpn_mmss(int(zero_if_null(record[0]))), + AvgServiceTimeQuery.seconds_to_jpn_mmss(int(zero_if_null(record[1]))), + ) + + return Stat( + total_sales_all_time=Product.to_price_str(total_sales_all_time), + total_sales_today=Product.to_price_str(total_sales_today), + total_items_all_time=total_items_all_time, + total_items_today=total_items_today, + sales_summary_list=sales_summary_list, + avg_service_time_all=avg_service_time_all, + avg_service_time_recent=avg_service_time_recent, + ) @router.get("/stat", response_class=HTMLResponse) async def get_stat(request: Request): await export_placements() - ( - total_sales_all_time, - total_sales_today, - total_items_all_time, - total_items_today, - sales_summary_list, - ) = await compute_total_sales() - ( - average_service_time_all, - average_service_time_recent, - ) = await compute_average_service_time() - return HTMLResponse( - templates.stat( - request, - total_sales_all_time, - total_sales_today, - total_items_all_time, - total_items_today, - sales_summary_list, - average_service_time_all, - average_service_time_recent, - ) - ) + return HTMLResponse(tmp_stat(request, await construct_stat())) + + +WAITING_ORDER_COUNT_QUERY: sqlalchemy.Compiled = ( + sqlmodel.select(sqlmodel.func.count(col(Placement.placement_id))) + .where(col(Placement.completed_at).is_(None) & col(Placement.canceled_at).is_(None)) + .compile() +) @router.get("/wait-estimates", response_class=HTMLResponse) -async def get_estimates(request: Request): - return HTMLResponse(templates.wait_estimates(request)) - - -@router.post("/wait-estimates/service-time", response_class=Response) -async def post_estimates(): - ( - average_service_time_all, - average_service_time_recent, - ) = await compute_average_service_time() - if average_service_time_recent == "0 分 0 秒": - average_service_time_recent = "待ち時間なし" - return average_service_time_recent - - -@router.post("/wait-estimates/waiting-orders", response_class=Response) -async def post_orders(): - waiting_orders = await compute_waiting_orders() - waiting_orders = str(waiting_orders) + "人" - return waiting_orders +async def get_estimates( + request: Request, hx_request: Annotated[str | None, Header()] = None +): + async with database.transaction(): + estimate_record = await database.fetch_one(str(AvgServiceTimeQuery.recent())) + waiting_order_count = await database.fetch_val(str(WAITING_ORDER_COUNT_QUERY)) + + assert estimate_record is not None + estimate = int(zero_if_null(estimate_record[0])) + + if estimate == 0: + estimate_str = "待ち時間なし" + else: + estimate_str = AvgServiceTimeQuery.seconds_to_jpn_mmss(estimate) + + if hx_request == "true": + template = tmp_wait_estimate_component + else: + template = tmp_wait_estimate_page + return HTMLResponse(template(request, estimate_str, waiting_order_count)) diff --git a/app/routers/test_placements.py b/app/routers/test_placements.py new file mode 100644 index 0000000..1df86ba --- /dev/null +++ b/app/routers/test_placements.py @@ -0,0 +1,59 @@ +import sqlparse +from inline_snapshot import snapshot + +from .placements import query_incoming, query_placed_items_incoming, query_resolved + + +def format_sql(sql: object): + return sqlparse.format(sql, keyword_case="upper", reindent=True, wrap_after=80) + + +def test_incoming_placed_items_query(): + assert format_sql(str(query_placed_items_incoming)) == snapshot( + """\ +SELECT placed_items.placement_id, placed_items.product_id, count(placed_items.product_id) AS COUNT, + products.name, products.filename, unixepoch(placements.placed_at) AS placed_at +FROM placed_items +JOIN products ON products.product_id = placed_items.product_id +JOIN placements ON placements.placement_id = placed_items.placement_id +WHERE placed_items.supplied_at IS NULL + AND placements.canceled_at IS NULL + AND placements.completed_at IS NULL +GROUP BY placed_items.placement_id, placed_items.product_id +ORDER BY placed_items.product_id ASC, placed_items.placement_id ASC\ +""" + ) + + +def test_incoming_placements_query(): + assert format_sql(str(query_incoming)) == snapshot( + """\ +SELECT placements.placement_id, unixepoch(placements.placed_at) AS placed_at, placed_items.product_id, + unixepoch(placed_items.supplied_at) AS supplied_at, count(placed_items.product_id) AS COUNT, products.name +FROM placements +JOIN placed_items ON placements.placement_id = placed_items.placement_id +JOIN products ON products.product_id = placed_items.product_id +WHERE placements.canceled_at IS NULL + AND placements.completed_at IS NULL +GROUP BY placements.placement_id, placed_items.product_id +ORDER BY placements.placement_id ASC, placed_items.product_id ASC\ +""" + ) + + +def test_resolved_placements_query(): + assert format_sql(str(query_resolved)) == snapshot( + """\ +SELECT placements.placement_id, unixepoch(placements.placed_at) AS placed_at, + unixepoch(placements.canceled_at) AS canceled_at, unixepoch(placements.completed_at) AS completed_at, + placed_items.product_id, unixepoch(placed_items.supplied_at) AS supplied_at, + count(placed_items.product_id) AS COUNT, products.name, products.price +FROM placements +JOIN placed_items ON placements.placement_id = placed_items.placement_id +JOIN products ON products.product_id = placed_items.product_id +WHERE placements.canceled_at IS NOT NULL + OR placements.completed_at IS NOT NULL +GROUP BY placements.placement_id, placed_items.product_id +ORDER BY placements.placement_id ASC, placed_items.product_id ASC\ +""" + ) diff --git a/app/routers/test_stat.py b/app/routers/test_stat.py new file mode 100644 index 0000000..2a080e5 --- /dev/null +++ b/app/routers/test_stat.py @@ -0,0 +1,66 @@ +import sqlparse +from inline_snapshot import snapshot + +from .stat import AvgServiceTimeQuery, WAITING_ORDER_COUNT_QUERY, TOTAL_SALES_QUERY + + +def format_sql(sql: object): + return sqlparse.format(sql, keyword_case="upper", reindent=True, wrap_after=80) + + +def test_total_sales_query(): + assert format_sql(str(TOTAL_SALES_QUERY)) == snapshot( + """\ +SELECT products.product_id, count(products.product_id) AS COUNT, + count(products.product_id) FILTER ( + WHERE date(placements.placed_at, + 'localtime') = date('now', + 'localtime')) AS count_today, products.name, products.filename, products.price, + sum(products.price) AS total_sales, + sum(products.price) FILTER ( + WHERE date(placements.placed_at, + 'localtime') = date('now', + 'localtime')) AS total_sales_today, products.no_stock +FROM placed_items +JOIN placements ON placements.placement_id = placed_items.placement_id +JOIN products ON products.product_id = placed_items.product_id +WHERE placements.canceled_at IS NULL +GROUP BY products.product_id\ +""" + ) + + +def test_avg_service_time_query_recent(): + assert format_sql(str(AvgServiceTimeQuery.recent())) == snapshot( + """\ +SELECT avg(CASE + WHEN ((unixepoch() - unixepoch(placements.completed_at)) / CAST(60 AS NUMERIC) < 30) THEN unixepoch(placements.completed_at) - unixepoch(placements.placed_at) + END) AS recent +FROM placements +WHERE placements.completed_at IS NOT NULL\ +""" + ) + + +def test_avg_service_time_query(): + assert format_sql(str(AvgServiceTimeQuery.all_and_recent())) == snapshot( + """\ +SELECT avg(unixepoch(placements.completed_at) - unixepoch(placements.placed_at)) AS "all", + avg(CASE + WHEN ((unixepoch() - unixepoch(placements.completed_at)) / CAST(60 AS NUMERIC) < 30) THEN unixepoch(placements.completed_at) - unixepoch(placements.placed_at) + END) AS recent +FROM placements +WHERE placements.completed_at IS NOT NULL\ +""" + ) + + +def test_waiting_order_count_query(): + assert format_sql(str(WAITING_ORDER_COUNT_QUERY)) == snapshot( + """\ +SELECT count(placements.placement_id) AS count_1 +FROM placements +WHERE placements.completed_at IS NULL + AND placements.canceled_at IS NULL\ +""" + ) diff --git a/app/store/__init__.py b/app/store/__init__.py index 04b87ce..72e1af9 100644 --- a/app/store/__init__.py +++ b/app/store/__init__.py @@ -1,7 +1,4 @@ from datetime import datetime, timezone -from enum import Enum, auto -from functools import partial -from typing import Any, Awaitable, Callable, Literal, Mapping import sqlalchemy import sqlmodel @@ -12,7 +9,7 @@ from . import placed_item, placement, product from ._helper import _colname from .placed_item import PlacedItem -from .placement import Placement +from .placement import ModifiedFlag, Placement # noqa: F401 from .product import Product DATABASE_URL = "sqlite:///db/app.db" @@ -33,336 +30,48 @@ async def delete_product(product_id: int): await database.execute(query) -def _to_time(unix_epoch: int) -> str: - return datetime.fromtimestamp(unix_epoch).strftime("%H:%M:%S") - - -type item_t = dict[str, int | str | None] -type placement_t = dict[str, int | list[item_t] | str | datetime | None] - - # TODO: there should be a way to use the unixepoch function without this boiler plate def unixepoch(attr: sa_orm.Mapped) -> sqlalchemy.Label: - colname = _colname(col(attr)) + colname = _colname(attr) alias = getattr(attr, "name") return sqlalchemy.literal_column(f"unixepoch({colname})").label(alias) -class PlacementsQuery(Enum): - incoming = auto() - resolved = auto() - - def placements(self) -> sqlalchemy.Select: - # Query from the placements table - query: sqlalchemy.Select = ( - sqlmodel.select(Placement.placement_id) - .group_by(col(Placement.placement_id)) - .order_by(col(Placement.placement_id).asc()) - .add_columns(unixepoch(col(Placement.placed_at))) - ) - - query = self._extra_timestamps(query) - - query = ( - # Query the list of placed items - query.select_from(sqlmodel.join(Placement, PlacedItem)) - .add_columns(col(PlacedItem.product_id)) - .add_columns(unixepoch(col(PlacedItem.supplied_at))) - .group_by(col(PlacedItem.product_id)) - .order_by(col(PlacedItem.product_id).asc()) - .add_columns(sqlmodel.func.count(col(PlacedItem.product_id)).label("count")) - # Query product information - .join(Product) - .add_columns(col(Product.name)) - ) - - # Include prices for resolved placements - if self == self.resolved: - query = query.add_columns(col(Product.price)) - - return query +async def supply_and_complete_placement_if_done(placement_id: int, product_id: int): + async with database.transaction(): + await PlacedItemTable._supply(placement_id, product_id) - def _extra_timestamps(self, query: sqlalchemy.Select) -> sqlalchemy.Select: - """Conditionally include/exclude extra timestamps.""" - match self: - case self.incoming: - return query.where( - col(Placement.canceled_at).is_(None) - & col(Placement.completed_at).is_(None) + update_query = ( + sqlmodel.update(Placement) + .where( + (col(Placement.placement_id) == placement_id) + & sqlmodel.select( + sqlmodel.func.count(col(PlacedItem.item_no)) + == sqlmodel.func.count(col(PlacedItem.supplied_at)) ) - case self.resolved: - return query.where( - col(Placement.canceled_at).isnot(None) - | col(Placement.completed_at).isnot(None) - ).add_columns( - unixepoch(col(Placement.canceled_at)), - unixepoch(col(Placement.completed_at)), - ) - - def placements_callbacks( - self, placements: list[placement_t] - ) -> tuple[ - Callable[[int, Mapping], None], - Callable[[Mapping], item_t], - Callable[[list[item_t]], None], - ]: - match self: - case self.incoming: - - def init_cb(placement_id: int, map: Mapping) -> None: - placements.append( - { - "placement_id": placement_id, - "placed_at": _to_time(map["placed_at"]), - } - ) - - def elem_cb(map: Mapping) -> item_t: - supplied_at = map["supplied_at"] - return { - "product_id": map["product_id"], - "count": map["count"], - "name": map["name"], - "supplied_at": _to_time(supplied_at) if supplied_at else None, - } - - def list_cb(items: list[item_t]) -> None: - if len(placements) > 0: - placements[-1]["items_"] = items - - return init_cb, elem_cb, list_cb - - case self.resolved: - total_price = 0 - - def init_cb(placement_id: int, map: Mapping) -> None: - canceled_at, completed_at = map["canceled_at"], map["completed_at"] - placements.append( - { - "placement_id": placement_id, - "placed_at": _to_time(map["placed_at"]), - "canceled_at": _to_time(canceled_at) - if canceled_at - else None, - "completed_at": _to_time(completed_at) - if completed_at - else None, - } - ) - nonlocal total_price - total_price = 0 - - def elem_cb(map: Mapping) -> item_t: - count, price = map["count"], map["price"] - nonlocal total_price - total_price += count * price - supplied_at = map["supplied_at"] - return { - "product_id": map["product_id"], - "count": count, - "name": map["name"], - "price": Product.to_price_str(price), - "supplied_at": _to_time(supplied_at) if supplied_at else None, - } - - def list_cb(items: list[item_t]) -> None: - if len(placements) > 0: - placements[-1]["items_"] = items - placements[-1]["total_price"] = Product.to_price_str( - total_price - ) - - return init_cb, elem_cb, list_cb - - -async def _agen_query_executor[T]( - db: Database, - query: str, - unique_key: Literal["placement_id"] | Literal["product_id"], - init_cb: Callable[[Any, Mapping], None], - elem_cb: Callable[[Mapping], T], - list_cb: Callable[[list[T]], None], -): - prev_unique_id = -1 - lst: list[T] = list() - async for map in db.iterate(query): - if (unique_id := map[unique_key]) != prev_unique_id: - prev_unique_id = unique_id - list_cb(lst) - init_cb(unique_id, map) - lst: list[T] = list() - lst.append(elem_cb(map)) - list_cb(lst) - - -def _placements_loader( - db: Database, status: PlacementsQuery -) -> Callable[[], Awaitable[list[placement_t]]]: - placements: list[placement_t] = [] - - query = str(status.placements().compile()) - init_cb, elem_cb, list_cb = status.placements_callbacks(placements) - load_placements = partial( - _agen_query_executor, db, query, "placement_id", init_cb, elem_cb, list_cb - ) - - async def load(): - placements.clear() - await load_placements() - return placements - - return load - - -load_incoming_placements = _placements_loader(database, PlacementsQuery.incoming) -load_resolved_placements = _placements_loader(database, PlacementsQuery.resolved) - - -async def load_one_resolved_placement(placement_id: int) -> placement_t | None: - query = PlacementsQuery.resolved.placements().where( - col(Placement.placement_id) == placement_id - ) - - rows_agen = database.iterate(query) - if (row := await anext(rows_agen, None)) is None: - return None - - canceled_at, completed_at = row["canceled_at"], row["completed_at"] - placement: placement_t = { - "placement_id": placement_id, - "placed_at": _to_time(row["placed_at"]), - "canceled_at": _to_time(canceled_at) if canceled_at else None, - "completed_at": _to_time(completed_at) if completed_at else None, - } - - total_price = 0 - - def to_item(row: Mapping) -> item_t: - count, price = row["count"], row["price"] - nonlocal total_price - total_price += count * price - supplied_at = row["supplied_at"] - return { - "product_id": row["product_id"], - "count": count, - "name": row["name"], - "price": Product.to_price_str(price), - "supplied_at": _to_time(supplied_at) if supplied_at else None, - } - - items = [to_item(row)] - async for row in rows_agen: - items.append(to_item(row)) - placement["items_"] = items - placement["total_price"] = Product.to_price_str(total_price) - - return placement - - -# NOTE:get placements by incoming order in datetime -# -# async def select_placements_by_incoming_order() -> dict[int, list[dict]]: -# query = f""" -# SELECT -# {PlacedItem.placement_id}, -# {PlacedItem.product_id}, -# {Product.name}, -# {Product.filename} -# FROM {PlacedItem.__tablename__} -# JOIN {Product.__tablename__} as {Product.__name__} ON {PlacedItem.product_id} = {Product.product_id} -# ORDER BY {PlacedItem.placement_id} ASC, {PlacedItem.item_no} ASC; -# """ -# placements: dict[int, list[dict]] = {} -# async for row in db.iterate(query): -# print(dict(row)) -# return placements - - -type placed_item_t = dict[str, int | str | list[dict[str, int | str]]] - - -def _placed_items_loader(db: Database) -> Callable[[], Awaitable[list[placed_item_t]]]: - query = ( - sqlmodel.select( - PlacedItem.placement_id, - sqlmodel.func.count(col(PlacedItem.product_id)).label("count"), - PlacedItem.product_id, - ) - .where(col(PlacedItem.supplied_at).is_(None)) # Filter out supplied items - .group_by(col(PlacedItem.placement_id), col(PlacedItem.product_id)) - .select_from(sqlmodel.join(PlacedItem, Product)) - .add_columns(col(Product.name), col(Product.filename)) - .join(Placement) - .add_columns(unixepoch(col(Placement.placed_at))) - ) - - query = ( - PlacementsQuery.incoming._extra_timestamps(query) - .order_by(col(PlacedItem.product_id).asc()) - .order_by(col(PlacedItem.placement_id).asc()) - ) - - query_str = str(query.compile()) - - placed_items: list[placed_item_t] = [] - - def init_cb(product_id: int, map: Mapping): - placed_items.append( - {"product_id": product_id, "name": map["name"], "filename": map["filename"]} + .where(col(PlacedItem.placement_id) == placement_id) + .scalar_subquery() + ) + .returning(col(Placement.placement_id).isnot(None)) ) - def elem_cb(map: Mapping) -> dict[str, int | str]: - return { - "placement_id": map["placement_id"], - "count": map["count"], - "placed_at": _to_time(map["placed_at"]), - } - - def list_cb(placements: list[dict[str, int | str]]): - if len(placed_items) > 0: - placed_items[-1]["placements"] = placements - - load_placed_products = partial( - _agen_query_executor, db, query_str, "product_id", init_cb, elem_cb, list_cb - ) - - async def load(): - placed_items.clear() - await load_placed_products() - return placed_items - - return load - + values = {"completed_at": datetime.now(timezone.utc)} + completed: bool | None = await database.fetch_val(update_query, values) -load_placed_items_incoming = _placed_items_loader(database) + async with PlacementTable.modified_cond_flag: + flag = ModifiedFlag.SUPPLIED + if completed is not None: + flag |= ModifiedFlag.RESOLVED + PlacementTable.modified_cond_flag.notify_all(flag) async def supply_all_and_complete(placement_id: int): async with database.transaction(): await PlacedItemTable._supply_all(placement_id) await PlacementTable._complete(placement_id) - async with PlacementTable.modified: - PlacementTable.modified.notify_all() - - -async def supply_and_complete_placement_if_done(placement_id: int, product_id: int): - async with database.transaction(): - await PlacedItemTable._supply(placement_id, product_id) - - update_query = sqlmodel.update(Placement).where( - (col(Placement.placement_id) == placement_id) - & sqlmodel.select( - sqlmodel.func.count(col(PlacedItem.item_no)) - == sqlmodel.func.count(col(PlacedItem.supplied_at)) - ) - .where(col(PlacedItem.placement_id) == placement_id) - .scalar_subquery() - ) - values = {"completed_at": datetime.now(timezone.utc)} - await database.execute(update_query, values) - - async with PlacementTable.modified: - PlacementTable.modified.notify_all() + async with PlacementTable.modified_cond_flag: + FLAG = ModifiedFlag.SUPPLIED | ModifiedFlag.RESOLVED + PlacementTable.modified_cond_flag.notify_all(FLAG) async def _startup_db() -> None: diff --git a/app/store/placed_item.py b/app/store/placed_item.py index 07574bc..2ead8cd 100644 --- a/app/store/placed_item.py +++ b/app/store/placed_item.py @@ -65,7 +65,7 @@ async def _supply(self, placement_id: int, product_id: int): async def _supply_all(self, placement_id: int): """ - Use `store.supply_all_and_complete` when the `completed_at` fields of + Use `supply_all_and_complete` when the `completed_at` fields of `placements` table should be updated as well. """ clause = col(PlacedItem.placement_id) == placement_id diff --git a/app/store/placement.py b/app/store/placement.py index 5e63de2..62cb15c 100644 --- a/app/store/placement.py +++ b/app/store/placement.py @@ -1,5 +1,6 @@ import asyncio from datetime import datetime, timezone +from enum import Flag, auto from typing import Annotated import sqlalchemy @@ -28,8 +29,39 @@ class Placement(sqlmodel.SQLModel, table=True): ) +class ModifiedFlag(Flag): + ORIGINAL = auto() + INCOMING = auto() + SUPPLIED = auto() + RESOLVED = auto() + PUT_BACK = auto() + + +class ModifiedCondFlag: + _condvar: asyncio.Condition = asyncio.Condition() + flag: ModifiedFlag = ModifiedFlag.ORIGINAL + + async def __aenter__(self): + await self._condvar.__aenter__() + + async def __aexit__(self, exc_type, exc, tb) -> None: + await self._condvar.__aexit__(exc_type, exc, tb) + + async def wait(self) -> ModifiedFlag: + await self._condvar.wait() + flag = self.flag + if len(self._condvar._waiters) == 0: + self.flag = ModifiedFlag.ORIGINAL + return flag + + def notify_all(self, flag: ModifiedFlag | None = None): + self._condvar.notify_all() + if flag is not None: + self.flag |= flag + + class Table: - modified: asyncio.Condition = asyncio.Condition() + modified_cond_flag = ModifiedCondFlag() def __init__(self, database: Database): self._db = database @@ -37,8 +69,8 @@ def __init__(self, database: Database): async def insert(self, placement_id: int) -> None: query = sqlmodel.insert(Placement) await self._db.execute(query, {"placement_id": placement_id}) - async with self.modified: - self.modified.notify_all() + async with self.modified_cond_flag: + self.modified_cond_flag.notify_all(ModifiedFlag.INCOMING) @staticmethod def _update(placement_id: int) -> sqlalchemy.Update: @@ -48,12 +80,12 @@ def _update(placement_id: int) -> sqlalchemy.Update: async def cancel(self, placement_id: int) -> None: values = {"canceled_at": datetime.now(timezone.utc), "completed_at": None} await self._db.execute(self._update(placement_id), values) - async with self.modified: - self.modified.notify_all() + async with self.modified_cond_flag: + self.modified_cond_flag.notify_all(ModifiedFlag.RESOLVED) async def _complete(self, placement_id: int) -> None: """ - Use `store.supply_all_and_complete` when the `supplied_at` fields of + Use `supply_all_and_complete` when the `supplied_at` fields of `placed_items` table should be updated as well. """ values = {"canceled_at": None, "completed_at": datetime.now(timezone.utc)} @@ -62,8 +94,8 @@ async def _complete(self, placement_id: int) -> None: async def reset(self, placement_id: int) -> None: values = {"canceled_at": None, "completed_at": None} await self._db.execute(self._update(placement_id), values) - async with self.modified: - self.modified.notify_all() + async with self.modified_cond_flag: + self.modified_cond_flag.notify_all(ModifiedFlag.PUT_BACK) async def by_placement_id(self, placement_id: int) -> Placement | None: query = sqlmodel.select(Placement).where(Placement.placement_id == placement_id) diff --git a/app/store/product.py b/app/store/product.py index e7e20a3..d124abd 100644 --- a/app/store/product.py +++ b/app/store/product.py @@ -1,8 +1,6 @@ import csv from typing import Annotated, Iterable -from uuid import UUID, uuid4 -import pydantic import sqlmodel from databases import Database from sqlmodel import col @@ -27,47 +25,6 @@ def to_price_str(price: int) -> str: return f"¥{price:,}" -class OrderSession(pydantic.BaseModel): - class CountedProduct(pydantic.BaseModel): - name: str - price: str - count: int = pydantic.Field(default=1) - - items: dict[UUID, Product] = pydantic.Field(default_factory=dict) - counted_products: dict[int, CountedProduct] = pydantic.Field(default_factory=dict) - total_count: int = pydantic.Field(default=0) - total_price: int = pydantic.Field(default=0) - - def clear(self): - self.total_count = 0 - self.total_price = 0 - self.items = {} - self.counted_products = {} - - def total_price_str(self) -> str: - return Product.to_price_str(self.total_price) - - def add(self, p: Product): - self.total_count += 1 - self.total_price += p.price - self.items[uuid4()] = p - if p.product_id in self.counted_products: - self.counted_products[p.product_id].count += 1 - else: - counted_product = self.CountedProduct(name=p.name, price=p.price_str()) - self.counted_products[p.product_id] = counted_product - - def delete(self, item_id: UUID): - if item_id in self.items: - self.total_count -= 1 - product = self.items.pop(item_id) - self.total_price -= product.price - if self.counted_products[product.product_id].count == 1: - self.counted_products.pop(product.product_id) - else: - self.counted_products[product.product_id].count -= 1 - - class Table: def __init__(self, database: Database): self._db = database diff --git a/app/store/test__init__.py b/app/store/test__init__.py deleted file mode 100644 index fe55536..0000000 --- a/app/store/test__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -import sqlalchemy -from inline_snapshot import snapshot - -from . import PlacementsQuery - - -def strip_lines(query: sqlalchemy.Select) -> str: - stripped_lines = [line.strip() for line in str(query).split("\n")] - return "\n".join(stripped_lines) - - -def test_incoming_placements_query(): - assert strip_lines(PlacementsQuery.incoming.placements()) == snapshot( - """\ -SELECT placements.placement_id, unixepoch(placements.placed_at) AS placed_at, placed_items.product_id, unixepoch(placed_items.supplied_at) AS supplied_at, count(placed_items.product_id) AS count, products.name -FROM placements JOIN placed_items ON placements.placement_id = placed_items.placement_id JOIN products ON products.product_id = placed_items.product_id -WHERE placements.canceled_at IS NULL AND placements.completed_at IS NULL GROUP BY placements.placement_id, placed_items.product_id ORDER BY placements.placement_id ASC, placed_items.product_id ASC\ -""" - ) - - -def test_resolved_placements_query(): - assert strip_lines(PlacementsQuery.resolved.placements()) == snapshot( - """\ -SELECT placements.placement_id, unixepoch(placements.placed_at) AS placed_at, unixepoch(placements.canceled_at) AS canceled_at, unixepoch(placements.completed_at) AS completed_at, placed_items.product_id, unixepoch(placed_items.supplied_at) AS supplied_at, count(placed_items.product_id) AS count, products.name, products.price -FROM placements JOIN placed_items ON placements.placement_id = placed_items.placement_id JOIN products ON products.product_id = placed_items.product_id -WHERE placements.canceled_at IS NOT NULL OR placements.completed_at IS NOT NULL GROUP BY placements.placement_id, placed_items.product_id ORDER BY placements.placement_id ASC, placed_items.product_id ASC\ -""" - ) diff --git a/app/templates.py b/app/templates.py index 6f9160b..b1fff2e 100644 --- a/app/templates.py +++ b/app/templates.py @@ -1,7 +1,7 @@ import os from functools import wraps from pathlib import Path -from typing import Any, Callable, Protocol, Optional +from typing import Any, Callable, Protocol import jinja2 from fastapi import Request @@ -9,8 +9,6 @@ from jinja2.ext import debug as debug_ext from .env import DEBUG -from .store import Product, placed_item_t, placement_t -from .store.product import OrderSession TEMPLATES_DIR = Path("app/templates") @@ -60,7 +58,7 @@ def macro_template[**P]( def type_signature(fn: _MacroArgHints[P]) -> _RenderMacroWithRequest[P]: @wraps(fn) - def with_request(request, *args: P.args, **kwargs: P.kwargs) -> str: + def with_request(request: Request, *args: P.args, **kwargs: P.kwargs) -> str: template = env.get_template(name, globals={"request": request, **globals}) return load_macro(template, macro_name)(*args, **kwargs) @@ -75,95 +73,5 @@ def layout( ): ... -@macro_template("index.html") -def index(): ... - - -class products: # namespace - @macro_template("products.html") - @staticmethod - def page(products: list[Product]): ... - - @macro_template("products.html", "editor") - @staticmethod - def editor(product: Product): ... - - @macro_template("products.html", "empty_editor") - @staticmethod - def empty_editor(): ... - - -class order: # namespace - @macro_template("order.html") - @staticmethod - def page(products: list[Product], session: OrderSession): ... - - @macro_template("order.html", "order_session") - @staticmethod - def session(session: OrderSession): ... - - -class placed_items_incoming: # namespace - @macro_template("placed-items-incoming.html") - @staticmethod - def page(placed_items: list[placed_item_t]): ... - - @macro_template("placed-items-incoming.html", "component") - @staticmethod - def component(placed_items: list[placed_item_t]): ... - - -class incoming_placements: # namespace - @macro_template("incoming-placements.html") - @staticmethod - def page(placements: list[placement_t]): ... - - @macro_template("incoming-placements.html", "component") - @staticmethod - def component(placements: list[placement_t]): ... - - -class resolved_placements: # namespace - @macro_template("resolved-placements.html") - @staticmethod - def page(placements: list[placement_t]): ... - - @macro_template("resolved-placements.html", "completed") - @staticmethod - def completed(placement: placement_t): ... - - @macro_template("resolved-placements.html", "canceled") - @staticmethod - def canceled(placement: placement_t): ... - - @macro_template("hx-post.html") def hx_post(path: str): ... - - -@macro_template("stat.html") -def stat( - total_sales_all_time: int, - total_sales_today: int, - total_items_all_time: int, - total_items_today: int, - sales_summary_list: list[dict[str, Any]], - average_service_time_all: str, - average_service_time_recent: str, -): ... - - -@macro_template("wait-estimates.html") -def wait_estimates(): ... - - -class components: # namespace - @macro_template("components/order-confirm.html") - @staticmethod - def order_confirm(session: OrderSession, error_status: Optional[str]): ... - - @macro_template("components/order-issued.html") - @staticmethod - def order_issued( - placement_id: Optional[int], session: OrderSession, error_status: Optional[str] - ): ... diff --git a/app/templates/components/order-confirm.html b/app/templates/components/order-confirm.html deleted file mode 100644 index c61d6b5..0000000 --- a/app/templates/components/order-confirm.html +++ /dev/null @@ -1,58 +0,0 @@ -{% macro order_confirm(session, error_status) %} - -{% endmacro %} diff --git a/app/templates/components/order-issued.html b/app/templates/components/order-issued.html deleted file mode 100644 index 39861ee..0000000 --- a/app/templates/components/order-issued.html +++ /dev/null @@ -1,58 +0,0 @@ -{% macro order_issued(placement_id, session, error_status) %} - -{% endmacro %} diff --git a/app/templates/hx-post.html b/app/templates/hx-post.html index 7642b49..be9b176 100644 --- a/app/templates/hx-post.html +++ b/app/templates/hx-post.html @@ -2,7 +2,7 @@ {% macro hx_post(path) %} {% call layout("") %} -
+ {% endcall %} {% endmacro %} diff --git a/app/templates/incoming-placements.html b/app/templates/incoming-placements.html index 5cad432..77d4a3d 100644 --- a/app/templates/incoming-placements.html +++ b/app/templates/incoming-placements.html @@ -7,9 +7,9 @@ {% macro incoming_placements(placements) %} {% call layout("未受取注文 - murchace", _head()) %} -
-
-
+{% endmacro %} + +{% macro confirm_modal(session) %} + +{% endmacro %} + +{% macro issued_modal(placement_id, session) %} + +{% endmacro %} + +{% macro _total(counted_products, total_count, total_price) %} + +
+

+ + {{ total_count }} 点 +

+

+ 合計金額 + {{ total_price }} +

+
+{% endmacro %} + +{% macro error_modal(message) %} + {% endmacro %} diff --git a/app/templates/placed-items-incoming.html b/app/templates/placed-items-incoming.html index c3df148..d396ca0 100644 --- a/app/templates/placed-items-incoming.html +++ b/app/templates/placed-items-incoming.html @@ -7,9 +7,9 @@ {% macro placed_items_incoming(placed_items) %} {% call layout("未受取商品 - murchace", _head()) %} -
-
-