diff --git a/src/sc_crawler/schemas.py b/src/sc_crawler/schemas.py index 2f706ae5..ad56d8a0 100644 --- a/src/sc_crawler/schemas.py +++ b/src/sc_crawler/schemas.py @@ -14,7 +14,7 @@ ImportString, PrivateAttr, ) -from sqlalchemy import DateTime +from sqlalchemy import DateTime, update from sqlalchemy.inspection import inspect from sqlalchemy.orm import declared_attr from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, select @@ -110,6 +110,18 @@ def __init__(self, *args, **kwargs): self.vendor.merge_dependent(self) +class Status(str, Enum): + ACTIVE = "active" + INACTIVE = "inactive" + + +class HasStatus(ScModel): + status: Status = Field( + default=Status.ACTIVE, + description="Status of the resource (active or inactive).", + ) + + class Json(BaseModel): """Custom base SQLModel class that supports dumping as JSON.""" @@ -117,11 +129,6 @@ def __json__(self): return self.model_dump() -class Status(str, Enum): - ACTIVE = "active" - INACTIVE = "inactive" - - class Country(ScModel, table=True): """Country and continent mapping.""" @@ -136,13 +143,15 @@ class Country(ScModel, table=True): datacenters: List["Datacenter"] = Relationship(back_populates="country") -class VendorComplianceLink(ScModel, table=True): +class VendorComplianceLinkBase(ScModel): vendor_id: str = Field(foreign_key="vendor.id", primary_key=True) compliance_framework_id: str = Field( foreign_key="compliance_framework.id", primary_key=True ) comment: Optional[str] = None + +class VendorComplianceLink(HasStatus, VendorComplianceLinkBase, table=True): vendor: "Vendor" = Relationship(back_populates="compliance_framework_links") compliance_framework: "ComplianceFramework" = Relationship( back_populates="vendor_links" @@ -271,55 +280,85 @@ def _get_methods(self): ) from exc return self._methods + def set_session(self, session): + """Attach a SQLModel session to use for merging dependent objects into the database.""" + self._session = session + + def merge_dependent(self, obj): + """Merge an object into the Vendor's SQLModel session (when available).""" + if self._session: + self._session.merge(obj) + + def set_table_rows_inactive(self, model: str, *args) -> None: + """Set this vendor's records to INACTIVE in a table + + Positional arguments can be used to pass further filters + (besides the default model.vendor_id filter) referencing the + model object with SQLModel syntax, e.g. + + >>> aws.set_table_rows_inactive(ServerPrice, ServerPrice.price < 10) + """ + if self._session: + query = update(model).where(model.vendor_id == self.id) + for arg in args: + query = query.where(arg) + self._session.execute(query.values(status=Status.INACTIVE)) + @log_start_end def get_compliance_frameworks(self): """Get the vendor's all compliance frameworks.""" + self.set_table_rows_inactive(VendorComplianceLink) self._get_methods().get_compliance_frameworks(self) @log_start_end def get_datacenters(self): """Get the vendor's all datacenters.""" + self.set_table_rows_inactive(Datacenter) self._get_methods().get_datacenters(self) @log_start_end def get_zones(self): """Get all the zones in the vendor's datacenters.""" + self.set_table_rows_inactive(Zone) self._get_methods().get_zones(self) @log_start_end def get_servers(self): """Get the vendor's all server types.""" + self.set_table_rows_inactive(Server) self._get_methods().get_servers(self) @log_start_end def get_server_prices(self): """Get the current standard/ondemand/reserved prices of all server types.""" + self.set_table_rows_inactive( + ServerPrice, ServerPrice.allocation != Allocation.SPOT + ) self._get_methods().get_server_prices(self) @log_start_end def get_server_prices_spot(self): """Get the current spot prices of all server types.""" + self.set_table_rows_inactive( + ServerPrice, ServerPrice.allocation == Allocation.SPOT + ) self._get_methods().get_server_prices_spot(self) @log_start_end def get_storage_prices(self): + self.set_table_rows_inactive(StoragePrice) self._get_methods().get_storage_prices(self) @log_start_end def get_traffic_prices(self): + self.set_table_rows_inactive(TrafficPrice) self._get_methods().get_traffic_prices(self) @log_start_end def get_ipv4_prices(self): + self.set_table_rows_inactive(Ipv4Price) self._get_methods().get_ipv4_prices(self) - def set_session(self, session): - self._session = session - - def merge_dependent(self, obj): - if self._session: - self._session.merge(obj) - class Datacenter(ScModel, table=True): id: str = Field(primary_key=True) @@ -547,6 +586,7 @@ class PriceTier(Json): # helper classes to inherit for most commonly used fields +# TODO rewrite above classes using helper classes as well class HasVendorPK(ScModel): @@ -573,7 +613,7 @@ class HasTraffic(ScModel): traffic_id: str = Field(foreign_key="traffic.id", primary_key=True) -class HasPriceFields(ScModel): +class HasPriceFieldsBase(ScModel): unit: PriceUnit # set to max price if tiered price: float @@ -584,6 +624,10 @@ class HasPriceFields(ScModel): currency: str = "USD" +class HasPriceFields(HasStatus, HasPriceFieldsBase): + pass + + class ServerPriceExtraFields(ScModel): operating_system: str allocation: Allocation = Allocation.ONDEMAND