Skip to content

Commit

Permalink
Merge pull request #8 from SpareCores/DEV-15
Browse files Browse the repository at this point in the history
DEV-15 rich progress bars
  • Loading branch information
daroczig authored Feb 26, 2024
2 parents 0f4c517 + 88c6ef2 commit 4400845
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ coverage.xml

# Sphinx documentation
docs/_build/

# Temp linters
flycheck_*
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"pydantic",
"pydantic_extra_types",
"pycountry",
"rich",
"sqlmodel",
"typer",
]
Expand Down
103 changes: 66 additions & 37 deletions src/sc_crawler/cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
from datetime import timedelta
from datetime import datetime, timedelta
from enum import Enum
from json import dumps
from typing import List

import typer
from cachier import set_default_params
from rich.live import Live
from rich.text import Text
from sqlmodel import Session, SQLModel, create_engine
from typing_extensions import Annotated

from . import vendors as vendors_module
from .logger import logger
from .logger import ProgressPanel, ScRichHandler, VendorProgressTracker, logger
from .lookup import compliance_frameworks, countries
from .schemas import Vendor
from .utils import hash_database
Expand Down Expand Up @@ -114,10 +116,8 @@ def custom_serializer(x):
)

# enable logging
channel = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s/%(module)s:%(funcName)s - %(levelname)s - %(message)s"
)
channel = ScRichHandler()
formatter = logging.Formatter("%(message)s")
channel.setFormatter(formatter)
logger.setLevel(log_level.value)
logger.addHandler(channel)
Expand All @@ -135,37 +135,66 @@ def custom_serializer(x):

engine = create_engine(connection_string, json_serializer=custom_serializer)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
# add/merge static objects to database
for compliance_framework in compliance_frameworks.values():
session.merge(compliance_framework)
for country in countries.values():
session.merge(country)
# get data for each vendor and then add/merge to database
for vendor in vendors:
logger.info("Starting to collect data from vendor: " + vendor.id)
vendor = session.merge(vendor)
vendor.set_session(session)
if Tables.compliance_frameworks in update_table:
vendor.inventory_compliance_frameworks()
if Tables.datacenters in update_table:
vendor.inventory_datacenters()
if Tables.zones in update_table:
vendor.inventory_zones()
if Tables.servers in update_table:
vendor.inventory_servers()
if Tables.server_prices in update_table:
vendor.inventory_server_prices()
if Tables.server_prices_spot in update_table:
vendor.inventory_server_prices_spot()
if Tables.storage_prices in update_table:
vendor.inventory_storage_prices()
if Tables.traffic_prices in update_table:
vendor.inventory_traffic_prices()
if Tables.ipv4_prices in update_table:
vendor.inventory_ipv4_prices()
session.merge(vendor)
session.commit()

pbars = ProgressPanel()
with Live(pbars.panels):
# show CLI arguments in the Metadata panel
pbars.metadata.append(Text("Update target(s): ", style="bold"))
pbars.metadata.append(Text(", ".join([x.value for x in update_table]) + "\n"))
pbars.metadata.append(Text("Connection type: ", style="bold"))
pbars.metadata.append(Text(connection_string.split(":")[0]))
pbars.metadata.append(Text(" Cache: ", style="bold"))
if cache:
pbars.metadata.append(Text("Enabled (" + str(cache_ttl) + "m)"))
else:
pbars.metadata.append(Text("Disabled"))
pbars.metadata.append(Text(" Time: ", style="bold"))
pbars.metadata.append(Text(str(datetime.now())))

with Session(engine) as session:
# add/merge static objects to database
for compliance_framework in compliance_frameworks.values():
session.merge(compliance_framework)
logger.info("%d Compliance Frameworks synced." % len(compliance_frameworks))
for country in countries.values():
session.merge(country)
logger.info("%d Countries synced." % len(countries))
# get data for each vendor and then add/merge to database
# TODO each vendor should open its own session and run in parallel
for vendor in vendors:
logger.info("Starting to collect data from vendor: " + vendor.id)
vendor = session.merge(vendor)
# register session to the Vendor so that dependent objects can auto-merge
vendor.session = session
# register progress bars so that helpers can update
vendor.progress_tracker = VendorProgressTracker(
vendor=vendor, progress_panel=pbars
)
vendor.progress_tracker.start_vendor(n=len(update_table))
if Tables.compliance_frameworks in update_table:
vendor.inventory_compliance_frameworks()
if Tables.datacenters in update_table:
vendor.inventory_datacenters()
if Tables.zones in update_table:
vendor.inventory_zones()
if Tables.servers in update_table:
vendor.inventory_servers()
if Tables.server_prices in update_table:
vendor.inventory_server_prices()
if Tables.server_prices_spot in update_table:
vendor.inventory_server_prices_spot()
if Tables.storage_prices in update_table:
vendor.inventory_storage_prices()
if Tables.traffic_prices in update_table:
vendor.inventory_traffic_prices()
if Tables.ipv4_prices in update_table:
vendor.inventory_ipv4_prices()
# reset current step name
vendor.progress_tracker.update_vendor(step="")
session.merge(vendor)
session.commit()

pbars.metadata.append(Text(" - " + str(datetime.now())))


if __name__ == "__main__":
Expand Down
226 changes: 226 additions & 0 deletions src/sc_crawler/logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
from __future__ import annotations

import logging
from datetime import datetime
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

from rich.console import ConsoleRenderable, Group
from rich.logging import RichHandler
from rich.panel import Panel
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
TaskID,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich.text import Text
from rich.traceback import Traceback

logger = logging.getLogger("sc_crawler")
logger.addHandler(logging.NullHandler())
Expand All @@ -8,14 +30,218 @@ def log_start_end(func):
"""Log the start and end of the decorated function."""

def wrap(*args, **kwargs):
# log start of the step
try:
self = args[0]
fname = f"{self.id}/{func.__name__}"
except Exception:
fname = func.__name__
logger.debug("Starting %s", fname)

# update Vendor's progress bar with the step name
try:
self.progress_tracker.update_vendor(
# drop `inventory_` prefix and prettify
step=func.__name__[10:].replace("_", " ")
)
except Exception:
logger.error("Cannot update step name in the Vendor's progress bar.")

# actually run step
result = func(*args, **kwargs)

# increment Vendor's progress bar
self.progress_tracker.advance_vendor()

# log end of the step and return
logger.debug("Finished %s", fname)
return result

return wrap


# https://github.com/Textualize/rich/issues/1532#issuecomment-1062431265
class ScRichHandler(RichHandler):
    """RichHandler whose path column shows ``<file name>:<function name>``."""

    def render(
        self,
        *,
        record: logging.LogRecord,
        traceback: Optional[Traceback],
        message_renderable: "ConsoleRenderable",
    ):
        """Build the renderable for a single log record.

        Args:
            record: The log record being emitted.
            traceback: Optional traceback renderable, shown after the message.
            message_renderable: The already rendered log message.

        Returns:
            Whatever the handler's `_log_render` callable returns for the record.
        """
        # file name of the emitting module joined with the emitting function
        source_file = Path(record.pathname).name
        path = ":".join((source_file, record.funcName))
        level = self.get_level_text(record)

        # date format is taken from the attached formatter, when there is one
        formatter = self.formatter
        time_format = formatter.datefmt if formatter is not None else None
        log_time = datetime.fromtimestamp(record.created)

        # message first, then the traceback if one was provided
        renderables = [message_renderable]
        if traceback:
            renderables.append(traceback)

        link_path = record.pathname if self.enable_link_path else None

        return self._log_render(
            self.console,
            renderables,
            log_time=log_time,
            time_format=time_format,
            level=level,
            path=path,
            line_no=record.lineno,
            link_path=link_path,
        )


class ProgressPanel:
    """Rich layout bundling the metadata text and the progress bars.

    The layout is a one-row grid: the first cell groups the metadata panel
    and the "Vendors" progress panel, the second cell holds the
    "Running tasks" progress panel.

    All renderables are created per instance, so creating more than one
    `ProgressPanel` does not add further rows to a grid shared between
    instances.
    """

    # progress bars of the vendors (one bar per vendor)
    vendors: Progress
    # progress bars of the currently running tasks
    tasks: Progress
    # free text shown in the metadata panel
    metadata: Text
    # grid holding all panels, to be passed to e.g. `rich.live.Live`
    panels: Table

    def __init__(self, *args, **kwargs):
        """Create the progress bars, the metadata text and the grid layout.

        Positional and keyword arguments are accepted for backward
        compatibility and are ignored.
        """
        self.vendors = Progress(
            TimeElapsedColumn(),
            TextColumn("{task.description}"),
            BarColumn(),
            # requires every vendor task to carry a `step` field
            TextColumn("({task.completed} of {task.total} steps): {task.fields[step]}"),
            expand=False,
        )
        self.tasks = Progress(
            TimeElapsedColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            expand=False,
            transient=True,
        )
        self.metadata = Text(justify="left")
        self.panels = Table.grid(padding=1)
        self.panels.add_row(
            Group(
                Panel(
                    self.metadata,
                    title="SC Crawler v" + version("sc_crawler"),
                    title_align="left",
                ),
                Panel(
                    self.vendors,
                    title="Vendors",
                    title_align="left",
                ),
            ),
            Panel(
                self.tasks,
                title="Running tasks",
                title_align="left",
                expand=False,
            ),
        )

    def add_task(self, description: str, n: int):
        """Add a progress bar to the list of running tasks.

        Args:
            description: Text shown in front of the progress bar.
            n: Overall number of steps of the task.

        Returns:
            The identifier returned by `Progress.add_task`.
        """
        return self.tasks.add_task(description, total=n)

    def add_vendor(self, vendor_name: str, steps: int):
        """Add a progress bar to the list of vendors.

        Args:
            vendor_name: Text shown in front of the progress bar.
            steps: Overall number of steps of the vendor.

        Returns:
            The identifier returned by `Progress.add_task`.
        """
        # the vendors' text column formats `task.fields[step]`, so the field
        # has to exist from the start
        return self.vendors.add_task(vendor_name, total=steps, step="")


if TYPE_CHECKING:
from .schemas import Vendor


class VendorProgressTracker:
    """Track the progress of a single vendor's inventory.

    Each tracker remembers the identifier of its own vendor progress bar and
    keeps its own list of running task identifiers, so several trackers can
    share one `ProgressPanel` without updating each other's bars.
    """

    vendor: Vendor
    progress_panel: ProgressPanel
    # re-exported attributes of the ProgressPanel
    vendors: Progress
    tasks: Progress
    metadata: Text
    # identifiers of the tasks started by this tracker and not hidden yet
    task_ids: List[TaskID]
    # identifier of this vendor's progress bar, set by `start_vendor`
    vendor_task_id: Optional[TaskID]

    def __init__(self, vendor: Vendor, progress_panel: ProgressPanel):
        """Bind the tracker to a vendor and to the shared progress panel.

        Args:
            vendor: The vendor whose progress is tracked. Its `name` is used
                as the label of the progress bars.
            progress_panel: Provider of the `vendors`, `tasks` and `metadata`
                attributes re-exported by the tracker.
        """
        self.vendor = vendor
        self.progress_panel = progress_panel
        self.vendors = progress_panel.vendors
        self.tasks = progress_panel.tasks
        self.metadata = progress_panel.metadata
        # per-instance state: a class-level list would be shared by all trackers
        self.task_ids = []
        self.vendor_task_id = None

    def _vendor_task(self) -> TaskID:
        """Return this vendor's progress bar identifier or fail loudly."""
        if self.vendor_task_id is None:
            raise RuntimeError(
                "start_vendor() must be called before updating the vendor's progress bar."
            )
        return self.vendor_task_id

    def _resolve_task(self, task_id: Optional[TaskID]) -> TaskID:
        """Return `task_id`, defaulting to the most recently started task."""
        # explicit None check: 0 is a valid (but falsy) task identifier
        return self.last_task() if task_id is None else task_id

    def start_vendor(self, n: int) -> TaskID:
        """Start a progress bar for the Vendor's steps.

        Args:
            n: Overall number of steps to show in the progress bar.

        Returns:
            TaskID: The progress bar's identifier, also stored on the tracker
                and used by `advance_vendor` and `update_vendor`.
        """
        self.vendor_task_id = self.vendors.add_task(self.vendor.name, total=n, step="")
        return self.vendor_task_id

    def advance_vendor(self, by: int = 1) -> None:
        """Increment the number of finished steps of this vendor.

        Args:
            by: Number of steps to advance.

        Raises:
            RuntimeError: If `start_vendor` has not been called yet.
        """
        self.vendors.update(self._vendor_task(), advance=by)

    def update_vendor(self, **kwargs) -> None:
        """Update this vendor's progress bar.

        Keyword arguments are passed on to the `update` method of `vendors`.

        Useful fields:
            - `step`: Name of the currently running step to be shown on the progress bar.

        Raises:
            RuntimeError: If `start_vendor` has not been called yet.
        """
        self.vendors.update(self._vendor_task(), **kwargs)

    def start_task(self, name: str, n: int) -> TaskID:
        """Start a progress bar in the list of current jobs.

        Besides returning the `TaskID`, it is also registered in `self.task_ids`
        as the last task, which is the default target of future `advance_task`,
        `update_task` and `hide_task` calls. `hide_task` removes the `TaskID`
        from `self.task_ids` again.

        Args:
            name: Name to show in front of the progress bar. Will be prefixed by Vendor's name.
            n: Overall number of steps to show in the progress bar.

        Returns:
            TaskID: The progress bar's identifier to be referenced in future updates.
        """
        self.task_ids.append(
            self.tasks.add_task(self.vendor.name + ": " + name, total=n)
        )
        return self.last_task()

    def last_task(self) -> TaskID:
        """Return the last registered TaskID.

        Raises:
            IndexError: If no task is registered.
        """
        return self.task_ids[-1]

    def advance_task(self, task_id: Optional[TaskID] = None, by: int = 1):
        """Increment the number of finished steps of a task.

        Args:
            task_id: The progress bar's identifier returned by `start_task`.
                Defaults to the most recently created task.
            by: Number of steps to advance.
        """
        self.tasks.update(self._resolve_task(task_id), advance=by)

    def update_task(self, task_id: Optional[TaskID] = None, **kwargs) -> None:
        """Update the task's progress bar.

        Args:
            task_id: The progress bar's identifier returned by `start_task`.
                Defaults to the most recently created task.

        Keyword Args:
            Passed on to the `update` method of `tasks`.
            See `Progress.update` for the supported keyword arguments:
            https://rich.readthedocs.io/en/stable/reference/progress.html#rich.progress.Progress.update
        """
        self.tasks.update(self._resolve_task(task_id), **kwargs)

    def hide_task(self, task_id: Optional[TaskID] = None):
        """Hide a task from the list of progress bars.

        The hidden task is also removed from `self.task_ids`.

        Args:
            task_id: The progress bar's identifier returned by `start_task`.
                Defaults to the most recently created task.
        """
        resolved = self._resolve_task(task_id)
        self.tasks.update(resolved, visible=False)
        # drop the task that was actually hidden, not blindly the last one
        if resolved in self.task_ids:
            self.task_ids.remove(resolved)
Loading

0 comments on commit 4400845

Please sign in to comment.