Skip to content

Commit

Permalink
Merge pull request #34 from prrao87/pydantic-v2
Browse files Browse the repository at this point in the history
Pydantic v2 Elasticsearch
  • Loading branch information
prrao87 authored Jul 17, 2023
2 parents 890ca43 + 9a38fce commit 85ab928
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 66 deletions.
2 changes: 1 addition & 1 deletion dbs/elasticsearch/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ ELASTIC_SERVICE = "elasticsearch"
API_PORT = 8002

# Container image tag
TAG = "0.1.0"
TAG = "0.2.0"

# Docker project namespace (defaults to the current folder name if not set)
COMPOSE_PROJECT_NAME = elastic_wine
10 changes: 6 additions & 4 deletions dbs/elasticsearch/api/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from pydantic import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Connection and deployment settings for the Elasticsearch-backed API.

    Every attribute below is required and has no default, so constructing
    ``Settings()`` fails validation unless each value is supplied (for
    example through the configured ``.env`` file).
    """

    # pydantic v2 style configuration. It replaces the v1 inner ``Config``
    # class: a model that defines both ``model_config`` and ``Config`` is
    # rejected by pydantic v2, so only this form is kept.
    model_config = SettingsConfigDict(
        env_file=".env",
        # The env file also carries keys that are not declared on this model
        # (e.g. API_PORT, COMPOSE_PROJECT_NAME), so unknown keys are accepted.
        extra="allow",
    )

    elastic_service: str
    elastic_user: str
    elastic_password: str
    elastic_url: str
    elastic_port: int
    elastic_index_alias: str
    # Container image tag (``TAG`` in the env file).
    tag: str
1 change: 1 addition & 0 deletions dbs/elasticsearch/api/routers/rest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from elasticsearch import AsyncElasticsearch
from fastapi import APIRouter, HTTPException, Query, Request

from schemas.retriever import (
CountByCountry,
FullTextSearch,
Expand Down
8 changes: 5 additions & 3 deletions dbs/elasticsearch/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
elasticsearch>=8.7.0
pydantic[dotenv]>=1.10.7, <2.0.0
fastapi>=0.95.0, <1.0.0
elasticsearch~=8.7.0
pydantic~=2.0.0
pydantic-settings~=2.0.0
python-dotenv>=1.0.0
fastapi~=0.100.0
httpx>=0.24.0
aiohttp>=3.8.4
uvicorn>=0.21.0, <1.0.0
Expand Down
31 changes: 13 additions & 18 deletions dbs/elasticsearch/schemas/retriever.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class FullTextSearch(BaseModel):
id: int
country: str
title: str
description: str | None
points: int
price: float | str | None
variety: str | None
winery: str | None

class Config:
schema_extra = {
model_config = ConfigDict(
json_schema_extra={
"example": {
"id": 3845,
"country": "Italy",
Expand All @@ -24,6 +15,16 @@ class Config:
"winery": "Castellinuzza e Piuca",
}
}
)

id: int
country: str
title: str
description: str | None
points: int
price: float | str | None
variety: str | None
winery: str | None


class TopWinesByCountry(BaseModel):
Expand All @@ -36,9 +37,6 @@ class TopWinesByCountry(BaseModel):
variety: str | None
winery: str | None

class Config:
validate_assignment = True


class TopWinesByProvince(BaseModel):
id: int
Expand All @@ -51,9 +49,6 @@ class TopWinesByProvince(BaseModel):
variety: str | None
winery: str | None

class Config:
validate_assignment = True


class MostWinesByVariety(BaseModel):
country: str
Expand Down
90 changes: 54 additions & 36 deletions dbs/elasticsearch/schemas/wine.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,13 @@
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator


class Wine(BaseModel):
id: int
points: int
title: str
description: str | None
price: float | None
variety: str | None
winery: str | None
vineyard: str | None
country: str | None
province: str | None
region_1: str | None
region_2: str | None
taster_name: str | None
taster_twitter_handle: str | None

class Config:
allow_population_by_field_name = True
validate_assignment = True
schema_extra = {
model_config = ConfigDict(
populate_by_name=True,
validate_assignment=True,
extra="allow",
str_strip_whitespace=True,
json_schema_extra={
"example": {
"id": 45100,
"points": 85,
Expand All @@ -37,26 +24,57 @@ class Config:
"taster_name": "Michael Schachner",
"taster_twitter_handle": "@wineschach",
}
}
},
)

# NOTE(review): this hook is written against the pydantic v1 `root_validator` API; it looks
# like the pre-migration counterpart of `_create_id` shown further down — confirm before reuse.
# Copies the value stored under "id" into an "_id" key of `values` and returns `values`.
# The lookup is a plain subscript, so a mapping without "id" fails here with KeyError.
@root_validator
def _create_id_field(cls, values):
"Elastic needs an _id field to create unique documents, so we just use the existing id field"
values["_id"] = values["id"]
return values

# NOTE(review): pydantic v1 pre-validator; presumably superseded by the
# `Field(..., alias="designation")` declaration on `vineyard` shown below — confirm.
# Removes the "designation" key from `values` (missing key is tolerated via the None default).
# Only a truthy designation is written to "vineyard", after stripping surrounding whitespace;
# an absent or empty designation leaves "vineyard" untouched.
@root_validator(pre=True)
def _get_vineyard(cls, values):
"Rename designation to vineyard"
vineyard = values.pop("designation", None)
if vineyard:
values["vineyard"] = vineyard.strip()
return values
id: int
points: int
title: str
description: str | None
price: float | None
variety: str | None
winery: str | None
vineyard: str | None = Field(..., alias="designation")
country: str | None
province: str | None
region_1: str | None
region_2: str | None
taster_name: str | None
taster_twitter_handle: str | None

@model_validator(mode="before")
@classmethod
def _fill_country_unknowns(cls, values):
    """Fill in a missing country with ``"Unknown"`` so the field stays queryable.

    Runs before field validation, on the raw input. A country counts as
    missing when the key is absent, the value is falsy (``None``, ``""``),
    the value is a blank string, or the value is the literal text ``"null"``.

    Args:
        values: Raw input handed to the model. Anything that is not a dict
            is passed through unchanged so that pydantic reports the type
            problem itself.

    Returns:
        The input unchanged when a usable country is present, otherwise a
        shallow copy with ``"country"`` set to ``"Unknown"``. The caller's
        dict is never modified.
    """
    if not isinstance(values, dict):
        return values
    country = values.get("country")
    if isinstance(country, str):
        # Blank strings carry no information; compare on the trimmed text.
        country = country.strip()
    if not country or country == "null":
        return {**values, "country": "Unknown"}
    return values

@model_validator(mode="before")
@classmethod
def _create_id(cls, values):
    """Mirror ``id`` into an ``_id`` key, which Elastic uses as the primary key.

    Runs before field validation, on the raw input. ``_id`` is not a declared
    field; it relies on the model accepting extra keys (``extra="allow"`` in
    this model's config).

    Args:
        values: Raw input handed to the model.

    Returns:
        A shallow copy of ``values`` with ``"_id"`` set to the value of
        ``"id"``. The input is returned unchanged when it is not a dict or
        has no ``"id"`` key, so that pydantic reports the problem as a
        validation error instead of this hook raising ``KeyError``.
    """
    if not isinstance(values, dict) or "id" not in values:
        return values
    return {**values, "_id": values["id"]}


if __name__ == "__main__":
    # Manual smoke test: build one Wine from a hand-written record and print
    # the dumped model so the effect of each validator can be inspected.
    from pprint import pprint

    sample_record = {
        "id": 45100,
        "points": 85,
        "title": "Balduzzi 2012 Reserva Merlot (Maule Valley)",
        "description": "Ripe in color and aromas, this chunky wine delivers heavy baked-berry and raisin aromas in front of a jammy, extracted palate. Raisin and cooked berry flavors finish plump, with earthy notes.",
        "price": 10,  # int on purpose: should come out as a float
        "variety": "Merlot",
        "winery": "Balduzzi",
        "designation": "Reserva",  # should surface under the renamed field
        "country": "null",  # should be replaced by the unknown-country marker
        "province": " Maule Valley ",  # surrounding whitespace should be stripped
        "region_1": "null",
        "region_2": "null",
        "taster_name": "Michael Schachner",
        "taster_twitter_handle": "@wineschach",
    }

    pprint(Wine(**sample_record).model_dump(), sort_dicts=False)
6 changes: 2 additions & 4 deletions dbs/elasticsearch/scripts/bulk_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import srsly
from dotenv import load_dotenv
from elasticsearch import AsyncElasticsearch, helpers
from pydantic.main import ModelMetaclass

sys.path.insert(1, os.path.realpath(Path(__file__).resolve().parents[1]))
from api.config import Settings
Expand Down Expand Up @@ -59,15 +58,14 @@ def get_json_data(data_dir: Path, filename: str) -> list[JsonBlob]:

def validate(
    data: list[JsonBlob],
    exclude_none: bool = False,
) -> list[JsonBlob]:
    """Validate raw records against ``Wine`` and return them as plain dicts.

    Args:
        data: Raw JSON records, one mapping per wine.
        exclude_none: When true, keys whose value is ``None`` are left out of
            each returned dict.

    Returns:
        One dict per input record, in input order, produced by dumping the
        validated ``Wine`` model.

    Raises:
        pydantic.ValidationError: If a record does not satisfy the ``Wine``
            schema.
    """
    # ``model_dump`` is the pydantic v2 name; ``.dict()`` is deprecated there.
    return [Wine(**item).model_dump(exclude_none=exclude_none) for item in data]


def process_chunks(data: list[JsonBlob]) -> list[JsonBlob]:
    """Validate one chunk of raw records, omitting ``None``-valued keys.

    Args:
        data: Raw JSON records belonging to a single chunk.

    Returns:
        The validated records as plain dicts, exactly as returned by
        ``validate`` with ``exclude_none=True``.
    """
    return validate(data, exclude_none=True)


Expand Down

0 comments on commit 85ab928

Please sign in to comment.