Skip to content

Commit

Permalink
Merge pull request #11 from SpareCores/DEV-36
Browse files Browse the repository at this point in the history
DEV-36 bulk inserts
  • Loading branch information
daroczig authored Mar 12, 2024
2 parents 673a884 + bba23cb commit 8341276
Show file tree
Hide file tree
Showing 7 changed files with 761 additions and 512 deletions.
109 changes: 109 additions & 0 deletions src/sc_crawler/insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from logging import DEBUG
from typing import List, Optional

from sqlalchemy.dialects.sqlite import insert
from sqlmodel import SQLModel

from .schemas import Vendor
from .str import space_after
from .utils import chunk_list, is_sqlite


def validate_items(
    model: SQLModel,
    items: List[dict],
    vendor: Optional[Vendor] = None,
    prefix: str = "",
) -> List[dict]:
    """Validate a list of items against a SQLModel definition.

    Every dict is passed through `model.model_validate` and then replaced,
    in place, by the `model_dump()` of the validated instance.

    Args:
        model: An SQLModel model to be used for validation.
        items: List of dictionaries to be checked against `model`. The list
            is modified in place.
        vendor: Optional Vendor instance used for logging and progress bar
            updates. When omitted, validation runs silently.
        prefix: Optional extra description for the model added in front of
            the model name in logs and progress bar updates.

    Returns:
        The same list with validated dicts in the same order. Note that
        missing fields have been filled in with default values (needed for
        bulk inserts).
    """
    model_name = model.get_table_name()
    if vendor:
        vendor.progress_tracker.start_task(
            name=f"Validating {space_after(prefix)}{model_name}(s)", n=len(items)
        )
    for i, item in enumerate(items):
        items[i] = model.model_validate(item).model_dump()
        if vendor:
            vendor.progress_tracker.advance_task()
    if vendor:
        vendor.progress_tracker.hide_task()
        # Must be an f-string: the placeholder was previously logged verbatim.
        vendor.log(
            f"{len(items)} {space_after(prefix)}{model_name}(s) objects validated",
            DEBUG,
        )
    return items


def bulk_insert_items(
    model: SQLModel,
    items: List[dict],
    vendor: Vendor,
    prefix: str = "",
    chunk_size: int = 100,
):
    """Bulk insert items into a SQLModel table with ON CONFLICT handling.

    Rows are inserted in chunks; on a primary key conflict the non-key
    columns of the existing row are updated from the incoming row. When the
    table has no non-key columns, conflicting rows are left untouched.

    Args:
        model: An SQLModel table definition with primary key(s).
        items: List of dicts with all columns of the model.
        vendor: The related Vendor instance used for database connection,
            logging and progress bar updates.
        prefix: Optional extra description for the model added in front of
            the model name in logs and progress bar updates.
        chunk_size: Maximum number of rows sent in a single INSERT statement.
    """
    model_name = model.get_table_name()
    columns = model.get_columns()
    # The conflict target does not depend on the chunk, so build it once.
    conflict_target = [getattr(model, c) for c in columns["primary_keys"]]
    vendor.progress_tracker.start_task(
        name=f"Syncing {space_after(prefix)}{model_name}(s)", n=len(items)
    )
    # need to split list into smaller chunks to avoid "too many SQL variables"
    for chunk in chunk_list(items, chunk_size):
        query = insert(model).values(chunk)
        if columns["attributes"]:
            query = query.on_conflict_do_update(
                index_elements=conflict_target,
                set_={c: query.excluded[c] for c in columns["attributes"]},
            )
        else:
            # on_conflict_do_update refuses an empty `set_` dict, so a table
            # consisting only of primary keys has nothing to update.
            query = query.on_conflict_do_nothing(index_elements=conflict_target)
        vendor.session.execute(query)
        vendor.progress_tracker.advance_task(by=len(chunk))
    vendor.progress_tracker.hide_task()
    vendor.log(f"{len(items)} {space_after(prefix)}{model_name}(s) synced.")


def insert_items(model: SQLModel, items: List[dict], vendor: Vendor, prefix: str = ""):
    """Write items into the model's database table using bulk insert or merge.

    With a SQLite session the items are validated and bulk inserted; with any
    other database each item is validated and merged one by one through the
    session (slower).

    Args:
        model: An SQLModel table definition with primary key(s).
        items: List of dicts with all columns of the model.
        vendor: The related Vendor instance used for database connection,
            logging and progress bar updates.
        prefix: Optional extra description for the model added in front of
            the model name in logs and progress bar updates.
    """
    model_name = model.get_table_name()

    if is_sqlite(vendor.session):
        validated = validate_items(model, items, vendor, prefix)
        bulk_insert_items(model, validated, vendor, prefix)
        return

    tracker_label = f"Syncing {space_after(prefix)}{model_name}(s)"
    vendor.progress_tracker.start_task(name=tracker_label, n=len(items))
    for raw in items:
        # vendor's auto session.merge doesn't work due to SQLmodel bug:
        # - https://github.com/tiangolo/sqlmodel/issues/6
        # - https://github.com/tiangolo/sqlmodel/issues/342
        # so need to trigger the merge manually
        instance = model.model_validate(raw)
        vendor.session.merge(instance)
        vendor.progress_tracker.advance_task()
    vendor.progress_tracker.hide_task()
    vendor.log(f"{len(items)} {space_after(prefix)}{model_name}(s) synced.")
37 changes: 14 additions & 23 deletions src/sc_crawler/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,22 @@ def __tablename__(cls) -> str:
"""Generate tables names using all-lowercase snake_case."""
return snake_case(cls.__name__)

@classmethod
def get_columns(cls) -> dict:
    """Return the table's column names grouped by role.

    Returns:
        A dict with three lists of column names: `all` (every column of
        the table), `primary_keys` (primary key columns) and `attributes`
        (every column that is not a primary key).
    """
    columns = cls.__table__.columns.keys()
    pks = [pk.name for pk in inspect(cls).primary_key]
    # Build the lookup set once instead of once per column.
    pk_names = set(pks)
    attributes = [name for name in columns if name not in pk_names]
    return {"all": columns, "primary_keys": pks, "attributes": attributes}

@classmethod
def get_table_name(cls) -> str:
    """Return the class's `__tablename__` as a plain string."""
    table_name = cls.__tablename__
    return str(table_name)

@classmethod
def hash(cls, session, ignored: List[str] = ["observed_at"]) -> dict:
pks = sorted([key.name for key in inspect(cls).primary_key])
pks = sorted(cls.get_columns()["primary_keys"])
rows = session.exec(statement=select(cls))
# no use of a generator as will need to serialize to JSON anyway
hashes = {}
Expand All @@ -102,18 +110,6 @@ def hash(cls, session, ignored: List[str] = ["observed_at"]) -> dict:
hashes[rowkeys] = rowhash
return hashes

def __init__(self, *args, **kwargs):
    """Initialize the instance, then merge it with the database if possible.

    After the regular initialization, the object is handed over to its
    parent vendor's `merge_dependent` when the instance has a truthy
    `vendor` that in turn has a truthy `session`.
    """
    super().__init__(*args, **kwargs)
    if hasattr(self, "vendor") and self.vendor and self.vendor.session:
        self.vendor.merge_dependent(self)


class Json(BaseModel):
"""Custom base SQLModel class that supports dumping as JSON."""
Expand Down Expand Up @@ -260,16 +256,18 @@ class HasTraffic(ScModel):
# Actual SC data schemas and model definitions


class Country(ScModel, table=True):
"""Country and continent mapping."""

# Field definitions of a country record; declared without `table=True`,
# so this class only carries the column definitions.
class CountryBase(ScModel):
    # Primary key column; the declared default is None.
    id: str = Field(
        default=None,
        primary_key=True,
        description="Country code by ISO 3166 alpha-2.",
    )
    # Declared without a default value.
    continent: str = Field(description="Continent name.")


class Country(CountryBase, table=True):
"""Country and continent mapping."""

vendors: List["Vendor"] = Relationship(back_populates="country")
datacenters: List["Datacenter"] = Relationship(back_populates="country")

Expand Down Expand Up @@ -492,13 +490,6 @@ def register_progress_tracker(self, progress_tracker: VendorProgressTracker):
"""Attach a VendorProgressTracker to use for updating progress bars."""
self._progress_tracker = progress_tracker

def merge_dependent(self, obj):
    """Merge `obj` into this object's session; do nothing without a session."""
    if not self.session:
        return
    # TODO investigate SAWarning
    # on obj associated with vendor before added to session?
    self.session.merge(obj)

def set_table_rows_inactive(self, model: str, *args) -> None:
"""Set this vendor's records to INACTIVE in a table
Expand Down
23 changes: 20 additions & 3 deletions src/sc_crawler/str.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from re import search, sub
from typing import Union


def wrap(text: str = "", before: str = " ", after: str = " ") -> str:
    """Surround a non-empty string with `before` and `after` (spaces by default).

    An empty string is returned unchanged.
    """
    if text == "":
        return text
    return before + text + after


def space_after(text: str = ""):
    """Append a single space to a non-empty string; keep an empty one as is."""
    return wrap(text=text, before="", after=" ")


# https://www.w3resource.com/python-exercises/string/python-data-type-string-exercise-97.php
Expand Down Expand Up @@ -32,14 +43,20 @@ def plural(text):
return text + "s"


def extract_last_number(text: str) -> Union[float, None]:
    """Extract the last non-negative number from a string.

    A number is a run of digits with an optional fractional part (`42`,
    `24.42`) or a bare fraction (`.5`). Dots that are not part of such a
    number (for example a full stop at the end of a sentence) are ignored.

    Args:
        text: The input string from which to extract the number.

    Returns:
        The last non-negative number found in the string, or None if no
        number is found.

    Examples:
        >>> extract_last_number("foo42")
        42.0
        >>> extract_last_number("foo24.42bar")
        24.42
        >>> extract_last_number("no number here.") is None
        True
    """
    # Capture only well-formed numbers: a greedy "[\d\.]+" would also grab a
    # lone "." or "1.2." and make float() raise ValueError.
    match = search(r"(\d+(?:\.\d+)?|\.\d+)[^0-9]*$", text)
    return float(match.group(1)) if match else None
50 changes: 48 additions & 2 deletions src/sc_crawler/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from enum import Enum
from hashlib import sha1
from json import dumps
from typing import List, Union
from typing import Any, Dict, Iterable, List, Union

from sqlmodel import Session, create_engine

from .schemas import tables
from .schemas import ScModel, tables


def jsoned_hash(*args, **kwargs):
Expand Down Expand Up @@ -60,3 +60,49 @@ def hash_database(
hashes = jsoned_hash(hashes)

return hashes


def chunk_list(items: List[Any], size: int) -> Iterable[List[Any]]:
    """Split a list into consecutive chunks of at most `size` elements.

    Args:
        items: The sequence to split; it must support `len()` and slicing.
        size: Maximum number of elements per chunk, must be positive.

    Yields:
        Slices of `items` in their original order; only the last one can be
        shorter than `size`. Nothing is yielded for an empty input.

    Raises:
        ValueError: If `size` is not positive. As this is a generator, the
            error surfaces when iteration starts.

    Examples:
        >>> [len(x) for x in chunk_list(range(10), 3)]
        [3, 3, 3, 1]
    """
    # A negative step would make range() empty and silently drop every item.
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    for start in range(0, len(items), size):
        yield items[start : start + size]


def scmodels_to_dict(
    scmodels: List[ScModel], keys: List[str] = ["id"]
) -> Dict[str, ScModel]:
    """Build a lookup dict of ScModel instances indexed by the given field(s).

    Each instance is registered under the value of every field listed in
    `keys`. When a field holds a list, every element of that list becomes
    a key (first level only, not recursively). Key conflicts are not
    checked: a later assignment overwrites an earlier one.

    Args:
        scmodels: list of ScModel instances
        keys: a list of strings referring to ScModel fields to be used as keys

    Examples:
        >>> from sc_crawler.vendors import aws
        >>> scmodels_to_dict([aws], keys=["id", "name"])
        {'aws': Vendor...
    """
    lookup = {}
    for field_name in keys:
        for instance in scmodels:
            field_value = getattr(instance, field_name)
            # Treat a scalar field value as a one-element list of keys.
            candidates = field_value if isinstance(field_value, list) else [field_value]
            for candidate in candidates:
                lookup[candidate] = instance
    return lookup


def is_sqlite(session: Session) -> bool:
    """Tell whether a SQLModel session is bound to SQLite or another database."""
    dialect_name = session.bind.dialect.name
    return dialect_name == "sqlite"
Loading

0 comments on commit 8341276

Please sign in to comment.