Skip to content

Commit

Permalink
Use full package/module path for the class name
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed Jun 27, 2024
1 parent 7aba486 commit daa0b5d
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 34 deletions.
71 changes: 54 additions & 17 deletions src/umlizer/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,63 @@ class ClassDef:
)


def get_full_class_path(cls: Type[Any], root_path: Path) -> str:
"""
Get the full package path for a given class, including parent packages.
Parameters
----------
cls : Type[Any]
The class to inspect.
root_path : Path
The root path of the project to determine the full package path.
Returns
-------
str
The full package path of the class.
"""
module = cls.__module__
module_file = importlib.import_module(module).__file__

if module_file is None:
raise ValueError(f'The module file {module} is invalid.')

root_path_str = str(root_path)

if not module_file.startswith(root_path_str):
raise ValueError(
f'The module file {module_file} is not within the '
f'root path {root_path}'
)

relative_path = os.path.relpath(module_file, root_path_str)
package_path = os.path.splitext(relative_path)[0].replace(os.sep, '.')

return f'{package_path}.{cls.__qualname__}'


def _get_fullname(entity: Type[Any]) -> str:
"""
Get the fully qualified name of a given entity.
Parameters
----------
entity : types.ModuleType
entity : Type[Any]
The entity for which the full name is required.
Returns
-------
str
Fully qualified name of the entity.
"""
if hasattr(entity, '__module__'):
return f'{entity.__module__}.{entity.__name__}'
elif hasattr(entity, '__name__'):
return entity.__name__
module = getattr(entity, '__module__', '')
qualname = getattr(entity, '__qualname__', str(entity))

if module:
return module + '.' + qualname

return str(entity)
return qualname


def _get_method_annotation(method: types.FunctionType) -> dict[str, str]:
Expand Down Expand Up @@ -99,9 +136,9 @@ def _get_dataclass_structure(

def _get_base_classes(klass: Type[Any]) -> list[Type[Any]]:
return [
c
for c in klass.__mro__
if c.__name__ not in ('object', klass.__name__)
base_class
for base_class in getattr(klass, '__bases__', [])
if base_class.__name__ != 'object'
]


Expand Down Expand Up @@ -160,19 +197,16 @@ def _get_classic_class_structure(klass: Type[Any]) -> ClassDef:
value = klass_anno.get(k, 'UNKNOWN')
fields[k] = getattr(value, '__value__', str(value))

if not fields:
# Extract attributes from the `__init__` method if defined there.
fields = _get_init_attributes(klass)
# Extract attributes from the `__init__` method if defined there.
fields.update(_get_init_attributes(klass))

return ClassDef(
fields=fields,
methods=_methods,
)


def _get_class_structure(
klass: Type[Any],
) -> ClassDef:
def _get_class_structure(klass: Type[Any], root_path: Path) -> ClassDef:
if dataclasses.is_dataclass(klass):
class_struct = _get_dataclass_structure(klass)
elif inspect.isclass(klass):
Expand All @@ -181,7 +215,7 @@ def _get_class_structure(
raise Exception('The given class is not actually a class.')

class_struct.module = klass.__module__
class_struct.name = _get_fullname(klass)
class_struct.name = get_full_class_path(klass, root_path)

class_struct.bases = []
for ref_class in _get_base_classes(klass):
Expand Down Expand Up @@ -356,9 +390,12 @@ def load_classes_definition(
module_files.append(path_str)

for file_path in module_files:
if verbose:
print(file_path)

classes_list.extend(_get_classes_from_module(file_path))

return [_get_class_structure(cls) for cls in classes_list]
return [_get_class_structure(cls, source) for cls in classes_list]


def dict_to_classdef(classes_list: list[dict[str, Any]]) -> list[ClassDef]:
Expand Down
51 changes: 51 additions & 0 deletions tests/ecommerce/offering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import ABC


class Offering(ABC):
def __init__(self, offering_id: int, name: str) -> None:
self.offering_id: int = offering_id
self.name: str = name


class Product(Offering):
"""Represents a product in the e-commerce system."""

def __init__(
self, product_id: int, name: str, price: float, stock: int
) -> None:
super().__init__(product_id, name)
self.price: float = price
self.stock: int = stock

def update_stock(self, amount: int) -> None:
"""Updates the stock quantity for the product."""
self.stock += amount

def get_product_info(self) -> str:
"""Returns the product's information."""
return (
f'Product ID: {self.product_id}, Name: {self.name}, '
f'Price: ${self.price}, Stock: {self.stock}'
)


class Service(Offering):
"""Represents a service in the e-commerce system."""

def __init__(
self, service_id: int, name: str, rate: float, duration: int
) -> None:
super().__init__(service_id, name)
self.rate: float = rate
self.duration: int = duration # duration in minutes

def update_duration(self, additional_minutes: int) -> None:
"""Updates the duration for the service."""
self.duration += additional_minutes

def get_service_info(self) -> str:
"""Returns the service's information."""
return (
f'Service ID: {self.service_id}, Name: {self.name}, '
f'Rate: ${self.rate}/hr, Duration: {self.duration} minutes'
)
2 changes: 1 addition & 1 deletion tests/ecommerce/order.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from typing import List
from user import User, Address
from product import Product
from offering import Product


class Order:
Expand Down
16 changes: 0 additions & 16 deletions tests/ecommerce/product.py

This file was deleted.

0 comments on commit daa0b5d

Please sign in to comment.