diff --git a/src/umlizer/inspector.py b/src/umlizer/inspector.py index e343835..444a9f5 100644 --- a/src/umlizer/inspector.py +++ b/src/umlizer/inspector.py @@ -31,13 +31,49 @@ class ClassDef: ) +def get_full_class_path(cls: Type[Any], root_path: Path) -> str: + """ + Get the full package path for a given class, including parent packages. + + Parameters + ---------- + cls : Type[Any] + The class to inspect. + root_path : Path + The root path of the project to determine the full package path. + + Returns + ------- + str + The full package path of the class. + """ + module = cls.__module__ + module_file = importlib.import_module(module).__file__ + + if module_file is None: + raise ValueError(f'The module file {module} is invalid.') + + root_path_str = str(root_path) + + if not module_file.startswith(root_path_str): + raise ValueError( + f'The module file {module_file} is not within the ' + f'root path {root_path}' + ) + + relative_path = os.path.relpath(module_file, root_path_str) + package_path = os.path.splitext(relative_path)[0].replace(os.sep, '.') + + return f'{package_path}.{cls.__qualname__}' + + def _get_fullname(entity: Type[Any]) -> str: """ Get the fully qualified name of a given entity. Parameters ---------- - entity : types.ModuleType + entity : Type[Any] The entity for which the full name is required. Returns @@ -45,12 +81,13 @@ def _get_fullname(entity: Type[Any]) -> str: str Fully qualified name of the entity. """ - if hasattr(entity, '__module__'): - return f'{entity.__module__}.{entity.__name__}' - elif hasattr(entity, '__name__'): - return entity.__name__ + module = getattr(entity, '__module__', '') + qualname = getattr(entity, '__qualname__', str(entity)) + + if module: + return module + '.' + qualname - return str(entity) + return qualname def _get_method_annotation(method: types.FunctionType) -> dict[str, str]: @@ -99,9 +136,9 @@ def _get_dataclass_structure( def _get_base_classes(klass: Type[Any]) -> list[Type[Any]]: return [ - c - for c in klass.__mro__ - if c.__name__ not in ('object', klass.__name__) + base_class + for base_class in getattr(klass, '__bases__', []) + if base_class.__name__ != 'object' ] @@ -160,9 +197,8 @@ def _get_classic_class_structure(klass: Type[Any]) -> ClassDef: value = klass_anno.get(k, 'UNKNOWN') fields[k] = getattr(value, '__value__', str(value)) - if not fields: - # Extract attributes from the `__init__` method if defined there. - fields = _get_init_attributes(klass) + # Extract attributes from the `__init__` method if defined there. + fields.update(_get_init_attributes(klass)) return ClassDef( fields=fields, @@ -170,9 +206,7 @@ def _get_classic_class_structure(klass: Type[Any]) -> ClassDef: ) -def _get_class_structure( - klass: Type[Any], -) -> ClassDef: +def _get_class_structure(klass: Type[Any], root_path: Path) -> ClassDef: if dataclasses.is_dataclass(klass): class_struct = _get_dataclass_structure(klass) elif inspect.isclass(klass): @@ -181,7 +215,7 @@ def _get_class_structure( raise Exception('The given class is not actually a class.') class_struct.module = klass.__module__ - class_struct.name = _get_fullname(klass) + class_struct.name = get_full_class_path(klass, root_path) class_struct.bases = [] for ref_class in _get_base_classes(klass): @@ -356,9 +390,12 @@ def load_classes_definition( module_files.append(path_str) for file_path in module_files: + if verbose: + print(file_path) + classes_list.extend(_get_classes_from_module(file_path)) - return [_get_class_structure(cls) for cls in classes_list] + return [_get_class_structure(cls, source) for cls in classes_list] def dict_to_classdef(classes_list: list[dict[str, Any]]) -> list[ClassDef]: diff --git a/tests/ecommerce/offering.py b/tests/ecommerce/offering.py new file mode 100644 index 0000000..0826008 --- /dev/null +++ b/tests/ecommerce/offering.py @@ -0,0 +1,51 @@ +from abc import ABC + + +class Offering(ABC): + def __init__(self, offering_id: int, name: str) -> None: + self.offering_id: int = offering_id + self.name: str = name + + +class Product(Offering): + """Represents a product in the e-commerce system.""" + + def __init__( + self, product_id: int, name: str, price: float, stock: int + ) -> None: + super().__init__(product_id, name) + self.price: float = price + self.stock: int = stock + + def update_stock(self, amount: int) -> None: + """Updates the stock quantity for the product.""" + self.stock += amount + + def get_product_info(self) -> str: + """Returns the product's information.""" + return ( + f'Product ID: {self.product_id}, Name: {self.name}, ' + f'Price: ${self.price}, Stock: {self.stock}' + ) + + +class Service(Offering): + """Represents a service in the e-commerce system.""" + + def __init__( + self, service_id: int, name: str, rate: float, duration: int + ) -> None: + super().__init__(service_id, name) + self.rate: float = rate + self.duration: int = duration # duration in minutes + + def update_duration(self, additional_minutes: int) -> None: + """Updates the duration for the service.""" + self.duration += additional_minutes + + def get_service_info(self) -> str: + """Returns the service's information.""" + return ( + f'Service ID: {self.service_id}, Name: {self.name}, ' + f'Rate: ${self.rate}/hr, Duration: {self.duration} minutes' + ) diff --git a/tests/ecommerce/order.py b/tests/ecommerce/order.py index 00df81b..f50e9ee 100644 --- a/tests/ecommerce/order.py +++ b/tests/ecommerce/order.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List from user import User, Address -from product import Product +from offering import Product class Order: diff --git a/tests/ecommerce/product.py b/tests/ecommerce/product.py deleted file mode 100644 index 771db37..0000000 --- a/tests/ecommerce/product.py +++ /dev/null @@ -1,16 +0,0 @@ -class Product: - """Represents a product in the e-commerce system.""" - - def __init__(self, product_id: int, name: str, price: float, stock: int): - self.product_id: int = product_id - self.name: str = name - self.price: float = price - self.stock: int = stock - - def update_stock(self, amount: int) -> None: - """Updates the stock quantity for the product.""" - self.stock += amount - - def get_product_info(self) -> str: - """Returns the product's information.""" - return f'Product ID: {self.product_id}, Name: {self.name}, Price: ${self.price}, Stock: {self.stock}'