Skip to content

Commit

Permalink
make sure all fetched records have STATUS and reset to INACTIVE befor…
Browse files Browse the repository at this point in the history
…e pull
  • Loading branch information
daroczig committed Feb 22, 2024
1 parent c8b142f commit 98439f9
Showing 1 changed file with 59 additions and 15 deletions.
74 changes: 59 additions & 15 deletions src/sc_crawler/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
ImportString,
PrivateAttr,
)
from sqlalchemy import DateTime
from sqlalchemy import DateTime, update
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import declared_attr
from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, select
Expand Down Expand Up @@ -110,18 +110,25 @@ def __init__(self, *args, **kwargs):
self.vendor.merge_dependent(self)


class Status(str, Enum):
ACTIVE = "active"
INACTIVE = "inactive"


class HasStatus(ScModel):
status: Status = Field(
default=Status.ACTIVE,
description="Status of the resource (active or inactive).",
)


class Json(BaseModel):
"""Custom base SQLModel class that supports dumping as JSON."""

def __json__(self):
return self.model_dump()


class Status(str, Enum):
ACTIVE = "active"
INACTIVE = "inactive"


class Country(ScModel, table=True):
"""Country and continent mapping."""

Expand All @@ -136,13 +143,15 @@ class Country(ScModel, table=True):
datacenters: List["Datacenter"] = Relationship(back_populates="country")


class VendorComplianceLink(ScModel, table=True):
class VendorComplianceLinkBase(ScModel):
vendor_id: str = Field(foreign_key="vendor.id", primary_key=True)
compliance_framework_id: str = Field(
foreign_key="compliance_framework.id", primary_key=True
)
comment: Optional[str] = None


class VendorComplianceLink(HasStatus, VendorComplianceLinkBase, table=True):
vendor: "Vendor" = Relationship(back_populates="compliance_framework_links")
compliance_framework: "ComplianceFramework" = Relationship(
back_populates="vendor_links"
Expand Down Expand Up @@ -271,55 +280,85 @@ def _get_methods(self):
) from exc
return self._methods

def set_session(self, session):
"""Attach a SQLModel session to use for merging dependent objects into the database."""
self._session = session

def merge_dependent(self, obj):
"""Merge an object into the Vendor's SQLModel session (when available)."""
if self._session:
self._session.merge(obj)

def set_table_rows_inactive(self, model: str, *args) -> None:
"""Set this vendor's records to INACTIVE in a table
Positional arguments can be used to pass further filters
(besides the default model.vendor_id filter) referencing the
model object with SQLModel syntax, e.g.
>>> aws.set_table_rows_inactive(ServerPrice, ServerPrice.price < 10)
"""
if self._session:
query = update(model).where(model.vendor_id == self.id)
for arg in args:
query = query.where(arg)
self._session.execute(query.values(status=Status.INACTIVE))

@log_start_end
def get_compliance_frameworks(self):
"""Get the vendor's all compliance frameworks."""
self.set_table_rows_inactive(VendorComplianceLink)
self._get_methods().get_compliance_frameworks(self)

@log_start_end
def get_datacenters(self):
"""Get the vendor's all datacenters."""
self.set_table_rows_inactive(Datacenter)
self._get_methods().get_datacenters(self)

@log_start_end
def get_zones(self):
"""Get all the zones in the vendor's datacenters."""
self.set_table_rows_inactive(Zone)
self._get_methods().get_zones(self)

@log_start_end
def get_servers(self):
"""Get the vendor's all server types."""
self.set_table_rows_inactive(Server)
self._get_methods().get_servers(self)

@log_start_end
def get_server_prices(self):
"""Get the current standard/ondemand/reserved prices of all server types."""
self.set_table_rows_inactive(
ServerPrice, ServerPrice.allocation != Allocation.SPOT
)
self._get_methods().get_server_prices(self)

@log_start_end
def get_server_prices_spot(self):
"""Get the current spot prices of all server types."""
self.set_table_rows_inactive(
ServerPrice, ServerPrice.allocation == Allocation.SPOT
)
self._get_methods().get_server_prices_spot(self)

@log_start_end
def get_storage_prices(self):
self.set_table_rows_inactive(StoragePrice)
self._get_methods().get_storage_prices(self)

@log_start_end
def get_traffic_prices(self):
self.set_table_rows_inactive(TrafficPrice)
self._get_methods().get_traffic_prices(self)

@log_start_end
def get_ipv4_prices(self):
self.set_table_rows_inactive(Ipv4Price)
self._get_methods().get_ipv4_prices(self)

def set_session(self, session):
self._session = session

def merge_dependent(self, obj):
if self._session:
self._session.merge(obj)


class Datacenter(ScModel, table=True):
id: str = Field(primary_key=True)
Expand Down Expand Up @@ -547,6 +586,7 @@ class PriceTier(Json):


# helper classes to inherit for most commonly used fields
# TODO rewrite above classes using helper classes as well


class HasVendorPK(ScModel):
Expand All @@ -573,7 +613,7 @@ class HasTraffic(ScModel):
traffic_id: str = Field(foreign_key="traffic.id", primary_key=True)


class HasPriceFields(ScModel):
class HasPriceFieldsBase(ScModel):
unit: PriceUnit
# set to max price if tiered
price: float
Expand All @@ -584,6 +624,10 @@ class HasPriceFields(ScModel):
currency: str = "USD"


class HasPriceFields(HasStatus, HasPriceFieldsBase):
pass


class ServerPriceExtraFields(ScModel):
operating_system: str
allocation: Allocation = Allocation.ONDEMAND
Expand Down

0 comments on commit 98439f9

Please sign in to comment.