-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from SpareCores/DEV-36
DEV-36 bulk inserts
- Loading branch information
Showing
7 changed files
with
761 additions
and
512 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from logging import DEBUG | ||
from typing import List, Optional | ||
|
||
from sqlalchemy.dialects.sqlite import insert | ||
from sqlmodel import SQLModel | ||
|
||
from .schemas import Vendor | ||
from .str import space_after | ||
from .utils import chunk_list, is_sqlite | ||
|
||
|
||
def validate_items(
    model: SQLModel,
    items: List[dict],
    vendor: Optional[Vendor] = None,
    prefix: str = "",
) -> List[dict]:
    """Validates a list of items against a SQLModel definition.

    Each item is passed through `model.model_validate` and replaced with the
    `model_dump()` of the validated object. The replacement happens in place:
    the list passed as `items` is modified and is also the object returned.

    Args:
        model: An SQLModel model to be used for validation.
        items: List of dictionaries to be checked against `model`.
        vendor: Optional Vendor instance used for logging and progress bar
            updates. When not provided, validation runs silently.
        prefix: Optional extra description for the model added in front of
            the model name in logs and progress bar updates.

    Returns:
        List of validated dicts in the same order. Note that missing fields
        have been filled in with default values (needed for bulk inserts).
    """
    model_name = model.get_table_name()
    if vendor:
        vendor.progress_tracker.start_task(
            name=f"Validating {space_after(prefix)}{model_name}(s)", n=len(items)
        )
    for i, item in enumerate(items):
        items[i] = model.model_validate(item).model_dump()
        if vendor:
            vendor.progress_tracker.advance_task()
    if vendor:
        vendor.progress_tracker.hide_task()
        # Must be an f-string so that the prefix is interpolated into the
        # message rather than logged as a literal "{space_after(prefix)}".
        vendor.log(
            f"{len(items)} {space_after(prefix)}{model_name}(s) objects validated",
            DEBUG,
        )
    return items
|
||
|
||
def bulk_insert_items(
    model: SQLModel, items: List[dict], vendor: Vendor, prefix: str = ""
):
    """Upsert `items` into the table of `model` in batches of 100 rows.

    Every batch is executed on `vendor.session` as a single SQLite INSERT
    statement with an ON CONFLICT DO UPDATE clause. The conflict target is
    the columns listed under `primary_keys` by `model.get_columns()`; on
    conflict, the columns listed under `attributes` are overwritten with the
    values of the row that was proposed for insertion.

    Args:
        model: An SQLModel table definition with primary key(s).
        items: List of dicts with all columns of the model.
        vendor: The related Vendor instance used for database connection,
            logging and progress bar updates.
        prefix: Optional extra description for the model added in front of
            the model name in logs and progress bar updates.
    """
    table_label = model.get_table_name()
    column_groups = model.get_columns()
    vendor.progress_tracker.start_task(
        name=f"Syncing {space_after(prefix)}{table_label}(s)", n=len(items)
    )
    # A single statement holding every row can fail with "too many SQL
    # variables", hence the batching.
    for batch in chunk_list(items, 100):
        statement = insert(model).values(batch)
        conflict_target = [
            getattr(model, key) for key in column_groups["primary_keys"]
        ]
        overwrites = {
            name: statement.excluded[name] for name in column_groups["attributes"]
        }
        upsert = statement.on_conflict_do_update(
            index_elements=conflict_target,
            set_=overwrites,
        )
        vendor.session.execute(upsert)
        vendor.progress_tracker.advance_task(by=len(batch))
    vendor.progress_tracker.hide_task()
    vendor.log(f"{len(items)} {space_after(prefix)}{table_label}(s) synced.")
|
||
|
||
def insert_items(model: SQLModel, items: List[dict], vendor: Vendor, prefix: str = ""):
    """Store `items` in the table of `model`, using the fastest supported path.

    On a SQLite session the items are validated and then written with bulk
    upserts. On any other database each item is validated and merged into
    the session one by one, which is slower.

    Args:
        model: An SQLModel table definition with primary key(s).
        items: List of dicts with all columns of the model.
        vendor: The related Vendor instance used for database connection,
            logging and progress bar updates.
        prefix: Optional extra description for the model added in front of
            the model name in logs and progress bar updates.
    """
    table_label = model.get_table_name()

    if is_sqlite(vendor.session):
        validated = validate_items(model, items, vendor, prefix)
        bulk_insert_items(model, validated, vendor, prefix)
        return

    vendor.progress_tracker.start_task(
        name=f"Syncing {space_after(prefix)}{table_label}(s)", n=len(items)
    )
    for raw_item in items:
        # The merge has to be triggered manually, as the vendor's automatic
        # session.merge does not work due to a SQLModel bug:
        # - https://github.com/tiangolo/sqlmodel/issues/6
        # - https://github.com/tiangolo/sqlmodel/issues/342
        record = model.model_validate(raw_item)
        vendor.session.merge(record)
        vendor.progress_tracker.advance_task()
    vendor.progress_tracker.hide_task()
    vendor.log(f"{len(items)} {space_after(prefix)}{table_label}(s) synced.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.