diff --git a/dbs/neo4j/api/routers/wine.py b/dbs/neo4j/api/routers/wine.py index 27d2b42..3f5a955 100644 --- a/dbs/neo4j/api/routers/wine.py +++ b/dbs/neo4j/api/routers/wine.py @@ -1,11 +1,10 @@ from fastapi import APIRouter, HTTPException, Query, Request from neo4j import AsyncManagedTransaction - from schemas.retriever import ( FullTextSearch, + MostWinesByVariety, TopWinesByCountry, TopWinesByProvince, - MostWinesByVariety, ) wine_router = APIRouter() diff --git a/dbs/neo4j/schemas/wine.py b/dbs/neo4j/schemas/wine.py index 10be66a..92924df 100644 --- a/dbs/neo4j/schemas/wine.py +++ b/dbs/neo4j/schemas/wine.py @@ -1,26 +1,6 @@ from pydantic import BaseModel, root_validator -class Location(BaseModel): - country: str | None - province: str | None - region_1: str | None - region_2: str | None - - @root_validator - def _fill_country_unknowns(cls, values): - "Fill in missing country values with 'Unknown'" - country = values.get("country") - if not country: - values["country"] = "Unknown" - return values - - -class Taster(BaseModel): - taster_name: str | None - taster_twitter_handle: str | None - - class Wine(BaseModel): id: int points: int @@ -68,19 +48,9 @@ def _get_vineyard(cls, values): return values @root_validator - def _add_location_dict(cls, values): - "Convert location attributes to a nested dict" - location = Location(**values).dict() - values["location"] = location - for key in location.keys(): - values.pop(key) - return values - - @root_validator - def _add_taster_dict(cls, values): - "Convert taster attributes to a nested dict" - taster = Taster(**values).dict() - values["taster"] = taster - for key in taster.keys(): - values.pop(key) + def _fill_country_unknowns(cls, values): + "Fill in missing country values with 'Unknown', as we always want this field to be queryable" + country = values.get("country") + if not country: + values["country"] = "Unknown" return values diff --git a/dbs/neo4j/scripts/build_graph.py b/dbs/neo4j/scripts/build_graph.py index 751c90c..df15bf6 100644 --- a/dbs/neo4j/scripts/build_graph.py +++ b/dbs/neo4j/scripts/build_graph.py @@ -9,9 +9,10 @@ from typing import Any from dotenv import load_dotenv -from neo4j import AsyncGraphDatabase, AsyncManagedTransaction, AsyncSession from pydantic.main import ModelMetaclass +from neo4j import AsyncGraphDatabase, AsyncManagedTransaction, AsyncSession + sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1])) from schemas.wine import Wine @@ -90,74 +91,37 @@ async def create_indexes_and_constraints(session: AsyncSession) -> None: await session.run(query) -async def wine_nodes(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None: +async def build_query(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None: query = """ - UNWIND $data AS d - MERGE (wine:Wine {wineID: d.id}) + UNWIND $data AS record + MERGE (wine:Wine {wineID: record.id}) SET wine += { - points: toInteger(d.points), - title: d.title, - description: d.description, - price: toFloat(d.price), - variety: d.variety, - winery: d.winery + points: record.points, + title: record.title, + description: record.description, + price: record.price, + variety: record.variety, + winery: record.winery, + vineyard: record.vineyard, + region_1: record.region_1, + region_2: record.region_2 } - """ - await tx.run(query, data=data) - - -async def wine_country_rels(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None: - query = """ - UNWIND $data AS d - MATCH (wine:Wine {wineID: d.id}) - UNWIND d.location as loc - WITH wine, loc - WHERE loc.country IS NOT NULL - MERGE (country:Country {countryName: loc.country}) + WITH record, wine + WHERE record.taster_name IS NOT NULL + MERGE (taster:Person {tasterName: record.taster_name}) + SET taster += {tasterTwitterHandle: record.taster_twitter_handle} + MERGE (wine)-[:TASTED_BY]->(taster) + WITH record, wine + MERGE (country:Country {countryName: record.country}) MERGE (wine)-[:IS_FROM_COUNTRY]->(country) - """ - await tx.run(query, data=data) - - -async def wine_province_rels(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None: - query = """ - UNWIND $data AS d - MATCH (wine:Wine {wineID: d.id}) - UNWIND d.location as loc - WITH wine, loc - WHERE loc.province IS NOT NULL - MERGE (province:Province {provinceName: loc.province}) + WITH record, wine, country + WHERE record.province IS NOT NULL + MERGE (province:Province {provinceName: record.province}) MERGE (wine)-[:IS_FROM_PROVINCE]->(province) - """ - await tx.run(query, data=data) - - -async def country_province_rels( - tx: AsyncManagedTransaction, data: list[JsonBlob] -) -> None: - query = """ - UNWIND $data AS d - UNWIND d.location as loc - WITH loc - WHERE loc.province IS NOT NULL AND loc.country IS NOT NULL - MATCH (country:Country {countryName: loc.country}) - MATCH (province:Province {provinceName: loc.province}) + WITH record, wine, country, province + WHERE record.province IS NOT NULL AND record.country IS NOT NULL MERGE (province)-[:IS_LOCATED_IN]->(country) - """ - await tx.run(query, data=data) - - -async def wine_taster_rels(tx: AsyncManagedTransaction, data: list[JsonBlob]) -> None: - query = """ - UNWIND $data AS d - MATCH (wine:Wine {wineID: d.id}) - UNWIND d.taster as t - WITH wine, t - WHERE t.taster_name IS NOT NULL - MERGE (taster:Person {tasterName: t.taster_name}) - SET taster += {tasterTwitterHandle: t.taster_twitter_handle} - MERGE (wine)-[:TASTED_BY]->(taster) - """ + """ await tx.run(query, data=data) @@ -172,11 +136,7 @@ async def main(files: list[str]) -> None: for file in files: data = read_jsonl_from_file(file) validated_data = validate(data, Wine) - await session.execute_write(wine_nodes, validated_data) - await session.execute_write(wine_country_rels, validated_data) - await session.execute_write(wine_province_rels, validated_data) - await session.execute_write(country_province_rels, validated_data) - await session.execute_write(wine_taster_rels, validated_data) + await session.execute_write(build_query, validated_data) print(f"Ingested {Path(file).name} to db")