Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DEV-15 rich progress bars #8

Merged
merged 22 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ coverage.xml

# Sphinx documentation
docs/_build/

# Temp linters
flycheck_*
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"pydantic",
"pydantic_extra_types",
"pycountry",
"rich",
"sqlmodel",
"typer",
]
Expand Down
103 changes: 66 additions & 37 deletions src/sc_crawler/cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
from datetime import timedelta
from datetime import datetime, timedelta
from enum import Enum
from json import dumps
from typing import List

import typer
from cachier import set_default_params
from rich.live import Live
from rich.text import Text
from sqlmodel import Session, SQLModel, create_engine
from typing_extensions import Annotated

from . import vendors as vendors_module
from .logger import logger
from .logger import ProgressPanel, ScRichHandler, VendorProgressTracker, logger
from .lookup import compliance_frameworks, countries
from .schemas import Vendor
from .utils import hash_database
Expand Down Expand Up @@ -114,10 +116,8 @@ def custom_serializer(x):
)

# enable logging
channel = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s/%(module)s:%(funcName)s - %(levelname)s - %(message)s"
)
channel = ScRichHandler()
formatter = logging.Formatter("%(message)s")
channel.setFormatter(formatter)
logger.setLevel(log_level.value)
logger.addHandler(channel)
Expand All @@ -135,37 +135,66 @@ def custom_serializer(x):

engine = create_engine(connection_string, json_serializer=custom_serializer)
SQLModel.metadata.create_all(engine)
with Session(engine) as session:
# add/merge static objects to database
for compliance_framework in compliance_frameworks.values():
session.merge(compliance_framework)
for country in countries.values():
session.merge(country)
# get data for each vendor and then add/merge to database
for vendor in vendors:
logger.info("Starting to collect data from vendor: " + vendor.id)
vendor = session.merge(vendor)
vendor.set_session(session)
if Tables.compliance_frameworks in update_table:
vendor.inventory_compliance_frameworks()
if Tables.datacenters in update_table:
vendor.inventory_datacenters()
if Tables.zones in update_table:
vendor.inventory_zones()
if Tables.servers in update_table:
vendor.inventory_servers()
if Tables.server_prices in update_table:
vendor.inventory_server_prices()
if Tables.server_prices_spot in update_table:
vendor.inventory_server_prices_spot()
if Tables.storage_prices in update_table:
vendor.inventory_storage_prices()
if Tables.traffic_prices in update_table:
vendor.inventory_traffic_prices()
if Tables.ipv4_prices in update_table:
vendor.inventory_ipv4_prices()
session.merge(vendor)
session.commit()

pbars = ProgressPanel()
with Live(pbars.panels):
# show CLI arguments in the Metadata panel
pbars.metadata.append(Text("Update target(s): ", style="bold"))
pbars.metadata.append(Text(", ".join([x.value for x in update_table]) + "\n"))
pbars.metadata.append(Text("Connection type: ", style="bold"))
pbars.metadata.append(Text(connection_string.split(":")[0]))
pbars.metadata.append(Text(" Cache: ", style="bold"))
if cache:
pbars.metadata.append(Text("Enabled (" + str(cache_ttl) + "m)"))
else:
pbars.metadata.append(Text("Disabled"))
pbars.metadata.append(Text(" Time: ", style="bold"))
pbars.metadata.append(Text(str(datetime.now())))

with Session(engine) as session:
# add/merge static objects to database
for compliance_framework in compliance_frameworks.values():
session.merge(compliance_framework)
logger.info("%d Compliance Frameworks synced." % len(compliance_frameworks))
for country in countries.values():
session.merge(country)
logger.info("%d Countries synced." % len(countries))
# get data for each vendor and then add/merge to database
# TODO each vendor should open its own session and run in parallel
for vendor in vendors:
logger.info("Starting to collect data from vendor: " + vendor.id)
vendor = session.merge(vendor)
# register session to the Vendor so that dependent objects can auto-merge
vendor.session = session
# register progress bars so that helpers can update
vendor.progress_tracker = VendorProgressTracker(
vendor=vendor, progress_panel=pbars
)
vendor.progress_tracker.start_vendor(n=len(update_table))
if Tables.compliance_frameworks in update_table:
vendor.inventory_compliance_frameworks()
if Tables.datacenters in update_table:
vendor.inventory_datacenters()
if Tables.zones in update_table:
vendor.inventory_zones()
if Tables.servers in update_table:
vendor.inventory_servers()
if Tables.server_prices in update_table:
vendor.inventory_server_prices()
if Tables.server_prices_spot in update_table:
vendor.inventory_server_prices_spot()
if Tables.storage_prices in update_table:
vendor.inventory_storage_prices()
if Tables.traffic_prices in update_table:
vendor.inventory_traffic_prices()
if Tables.ipv4_prices in update_table:
vendor.inventory_ipv4_prices()
# reset current step name
vendor.progress_tracker.update_vendor(step="")
session.merge(vendor)
session.commit()

pbars.metadata.append(Text(" - " + str(datetime.now())))


if __name__ == "__main__":
Expand Down
226 changes: 226 additions & 0 deletions src/sc_crawler/logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,26 @@
from __future__ import annotations

import logging
from datetime import datetime
from importlib.metadata import version
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

from rich.console import ConsoleRenderable, Group
from rich.logging import RichHandler
from rich.panel import Panel
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
TaskID,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich.text import Text
from rich.traceback import Traceback

# Package-wide logger, registered under the name "sc_crawler".
logger = logging.getLogger("sc_crawler")
# Attach a no-op handler by default; a real handler has to be added by the
# application (e.g. the CLI) to actually emit records.
logger.addHandler(logging.NullHandler())
Expand All @@ -8,14 +30,218 @@ def log_start_end(func):
"""Log the start and end of the decorated function."""

def wrap(*args, **kwargs):
    """Run the decorated step with logging and progress bar bookkeeping.

    The first positional argument is treated as the object owning the step
    (it is read for `id` and `progress_tracker`). All progress bar updates
    are best-effort: a missing owner or tracker is logged and never prevents
    the step from running or its result from being returned.

    Returns:
        Whatever the decorated function returns.
    """
    # the owner is optional, so that plain functions can be decorated as well
    self = args[0] if args else None

    # log start of the step
    try:
        fname = f"{self.id}/{func.__name__}"
    except Exception:
        fname = func.__name__
    logger.debug("Starting %s", fname)

    # drop the `inventory_` prefix (only when present) and prettify
    step = func.__name__
    if step.startswith("inventory_"):
        step = step[len("inventory_") :]
    step = step.replace("_", " ")

    # update Vendor's progress bar with the step name
    try:
        self.progress_tracker.update_vendor(step=step)
    except Exception:
        logger.error("Cannot update step name in the Vendor's progress bar.")

    # actually run step
    result = func(*args, **kwargs)

    # increment Vendor's progress bar; guarded like the update above, so that
    # a progress bar problem cannot discard the result of a finished step
    try:
        self.progress_tracker.advance_vendor()
    except Exception:
        logger.error("Cannot advance the Vendor's progress bar.")

    # log end of the step and return
    logger.debug("Finished %s", fname)
    return result

return wrap


# https://github.com/Textualize/rich/issues/1532#issuecomment-1062431265
class ScRichHandler(RichHandler):
    """RichHandler variant showing `<file name>:<function name>` as the path.

    The path column is built from the record's file name and the name of the
    function that emitted the record, instead of the file name alone.
    """

    def render(
        self,
        *,
        record: logging.LogRecord,
        traceback: Optional[Traceback],
        message_renderable: "ConsoleRenderable",
    ):
        """Render a log record via the handler's log renderer.

        Args:
            record: The log record being rendered.
            traceback: Optional traceback, shown below the message when given.
            message_renderable: The already formatted log message.

        Returns:
            The renderable produced by the handler's `_log_render` callable.
        """
        # message first, followed by the traceback when there is one
        renderables = [message_renderable]
        if traceback:
            renderables.append(traceback)

        # the date format is taken from the attached formatter, if any
        formatter = self.formatter
        time_format = formatter.datefmt if formatter is not None else None

        # source file name (without directories) plus the calling function
        source_file = Path(record.pathname).name
        path = source_file + ":" + record.funcName

        # link to the full path only when the handler is configured to do so
        link_path = record.pathname if self.enable_link_path else None

        return self._log_render(
            self.console,
            renderables,
            log_time=datetime.fromtimestamp(record.created),
            time_format=time_format,
            level=self.get_level_text(record),
            path=path,
            line_no=record.lineno,
            link_path=link_path,
        )


class ProgressPanel:
    """Bundle of Rich renderables showing the overall progress.

    Attributes:
        vendors: Progress bars of the vendors, one per vendor, showing the
            number of finished steps and the name of the current step.
        tasks: Transient progress bars of the currently running tasks.
        metadata: Free text shown in the top panel.
        panels: Grid holding all the above; this is the renderable to show.

    All renderables are created per instance, so creating another
    `ProgressPanel` starts from an empty layout instead of adding a row to
    (and sharing the progress bars of) a previously created one.
    """

    vendors: Progress
    tasks: Progress
    metadata: Text
    panels: Table

    def __init__(self, *args, **kwargs):
        """Create the renderables and arrange them in the grid.

        Positional and keyword arguments are accepted but not used.
        """
        self.vendors = Progress(
            TimeElapsedColumn(),
            TextColumn("{task.description}"),
            BarColumn(),
            TextColumn("({task.completed} of {task.total} steps): {task.fields[step]}"),
            expand=False,
        )
        self.tasks = Progress(
            TimeElapsedColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeRemainingColumn(),
            expand=False,
            transient=True,
        )
        self.metadata = Text(justify="left")
        self.panels = Table.grid(padding=1)
        # left column: metadata above the vendors; right column: running tasks
        self.panels.add_row(
            Group(
                Panel(
                    self.metadata,
                    title="SC Crawler v" + version("sc_crawler"),
                    title_align="left",
                ),
                Panel(
                    self.vendors,
                    title="Vendors",
                    title_align="left",
                ),
            ),
            Panel(
                self.tasks,
                title="Running tasks",
                title_align="left",
                expand=False,
            ),
        )

    def add_task(self, description: str, n: int):
        """Add a progress bar to the list of running tasks.

        Args:
            description: Text shown in front of the progress bar.
            n: Overall number of steps of the task.

        Returns:
            The identifier of the new progress bar.
        """
        return self.tasks.add_task(description, total=n)

    def add_vendor(self, vendor_name: str, steps: int):
        """Add a progress bar to the list of vendors.

        Args:
            vendor_name: Text shown in front of the progress bar.
            steps: Overall number of steps to run for the vendor.

        Returns:
            The identifier of the new progress bar.
        """
        # the vendors' column template references the `step` field, so it
        # needs to be set (empty at start) for the bar to be renderable
        return self.vendors.add_task(vendor_name, total=steps, step="")


if TYPE_CHECKING:
from .schemas import Vendor


class VendorProgressTracker:
    """Track the progress of a vendor's inventory on a `ProgressPanel`.

    Each tracker remembers the progress bar it started for its own vendor and
    keeps its own stack of started task bars, so several trackers can share
    one `ProgressPanel` without updating each other's bars.
    """

    vendor: Vendor
    progress_panel: ProgressPanel
    # re-export the renderables of the ProgressPanel
    vendors: Progress
    tasks: Progress
    metadata: Text
    # task bars started by this tracker and not yet hidden, most recent last
    task_ids: List[TaskID]

    def __init__(self, vendor: Vendor, progress_panel: ProgressPanel):
        """Bind the tracker to a vendor and to the panel showing its progress.

        Args:
            vendor: The vendor whose name is shown on the progress bars.
            progress_panel: The panel providing the progress bars.
        """
        self.vendor = vendor
        self.progress_panel = progress_panel
        self.vendors = progress_panel.vendors
        self.tasks = progress_panel.tasks
        self.metadata = progress_panel.metadata
        # per-instance list: a class-level default would be shared by all trackers
        self.task_ids = []
        # set by `start_vendor`
        self._vendor_task_id = None

    def _vendor_task(self) -> TaskID:
        """Return the identifier of this vendor's progress bar."""
        if self._vendor_task_id is not None:
            return self._vendor_task_id
        # `start_vendor` was not called on this tracker: fall back to the
        # first vendor progress bar of the panel
        return self.vendors.task_ids[0]

    def _resolve_task(self, task_id: Optional[TaskID]) -> TaskID:
        """Return `task_id`, or the most recently started task when it is None."""
        # explicit None check: the first TaskID is 0, which is falsy
        return self.last_task() if task_id is None else task_id

    def start_vendor(self, n: int) -> TaskID:
        """Starts a progress bar for the Vendor's steps.

        Args:
            n: Overall number of steps to show in the progress bar.

        Returns:
            TaskId: The progress bar's identifier to be referenced in future updates.
        """
        self._vendor_task_id = self.vendors.add_task(
            self.vendor.name, total=n, step=""
        )
        return self._vendor_task_id

    def advance_vendor(self, by: int = 1) -> None:
        """Increment the number of finished steps.

        Args:
            by: Number of steps to advance.
        """
        self.vendors.update(self._vendor_task(), advance=by)

    def update_vendor(self, **kwargs) -> None:
        """Update the vendor's progress bar.

        Useful fields:
            - `step`: Name of the currently running step to be shown on the progress bar.
        """
        self.vendors.update(self._vendor_task(), **kwargs)

    def start_task(self, name: str, n: int) -> TaskID:
        """Starts a progress bar in the list of current jobs.

        Besides returning the `TaskID`, it will also register it in `self.task_ids`
        as the last task, which will be the default value for future `advance_task`,
        `hide_task` etc calls. The latter will remove the `TaskID` from the `task_ids`.

        Args:
            name: Name to show in front of the progress bar. Will be prefixed by Vendor's name.
            n: Overall number of steps to show in the progress bar.

        Returns:
            TaskId: The progress bar's identifier to be referenced in future updates.
        """
        self.task_ids.append(
            self.tasks.add_task(self.vendor.name + ": " + name, total=n)
        )
        return self.last_task()

    def last_task(self) -> TaskID:
        """Return the last registered TaskID.

        Raises:
            IndexError: If there is no registered task.
        """
        return self.task_ids[-1]

    def advance_task(self, task_id: Optional[TaskID] = None, by: int = 1):
        """Increment the number of finished steps.

        Args:
            task_id: The progress bar's identifier returned by `start_task`.
                Defaults to the most recently created task.
            by: Number of steps to advance.
        """
        self.tasks.update(self._resolve_task(task_id), advance=by)

    def update_task(self, task_id: Optional[TaskID] = None, **kwargs) -> None:
        """Update the task's progress bar.

        Args:
            task_id: The progress bar's identifier returned by `start_task`.
                Defaults to the most recently created task.

        Keyword Args:
            Passed on to `Progress.update`:
            https://rich.readthedocs.io/en/stable/reference/progress.html#rich.progress.Progress.update
        """
        self.tasks.update(self._resolve_task(task_id), **kwargs)

    def hide_task(self, task_id: Optional[TaskID] = None):
        """Hide a task from the list of progress bars.

        Args:
            task_id: The progress bar's identifier returned by `start_task`.
                Defaults to the most recently created task.
        """
        resolved = self._resolve_task(task_id)
        self.tasks.update(resolved, visible=False)
        # unregister the task that was actually hidden, not just the last one
        if resolved in self.task_ids:
            self.task_ids.remove(resolved)
Loading
Loading