diff --git a/src/umlizer/inspector.py b/src/umlizer/inspector.py index 444a9f5..ef637c3 100644 --- a/src/umlizer/inspector.py +++ b/src/umlizer/inspector.py @@ -288,55 +288,60 @@ def _extract_module_name(module_path: str) -> tuple[str, str]: return module_path, module_name -def _get_classes_from_module(module_path: str) -> list[Type[Any]]: +def _get_classes_from_module(module_file_path: str) -> list[Type[Any]]: """ - Extract classes from a given module path using importlib.import_module. + Extract classes from a given module path using importlib. Parameters ---------- - module_path : str - The path to the module from which classes are to be extracted. + module_file_path : str + The path to the module file from which classes are to be extracted. Returns ------- list A list of class objects. """ - module_path, module_name = _extract_module_name(module_path) + module_path, module_name = _extract_module_name(module_file_path) original_path = copy.deepcopy(sys.path) sys.path.insert(0, module_path) try: - module = importlib.import_module(module_name) + spec = importlib.util.spec_from_file_location( + module_name, module_file_path + ) + if spec is None: + raise ImportError(f'Cannot find spec for {module_file_path}') + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + sys.path = original_path + + all_classes_exported = [] + + if hasattr(module, '__all__'): + all_classes_exported = [ + getattr(module, name) + for name in module.__all__ + if inspect.isclass(getattr(module, name)) + ] + + all_classes = [ + getattr(module, name) + for name in dir(module) + if inspect.isclass(getattr(module, name)) + and getattr(getattr(module, name), '__module__', None) + == module.__name__ + ] except KeyboardInterrupt: raise_error('KeyboardInterrupt', 1) except Exception as e: - short_module_path = '.'.join(module_path.split('/')[-3:]) + short_module_path = '.'.join(module_path.split(os.sep)[-3:]) print(f' Error loading module {short_module_path} '.center(80, '=')) print(e) print('.' * 80) sys.path = original_path return [] - - # If __all__ is defined, get only the classes listed in __all__ - all_classes_exported = [] - if hasattr(module, '__all__'): - for name in module.__all__: - if not inspect.isclass(getattr(module, name)): - continue - all_classes_exported.append(getattr(module, name)) - - # Get all classes defined directly in the module - all_classes = [] - for name in dir(module): - if not ( - inspect.isclass(getattr(module, name)) - and getattr(getattr(module, name), '__module__', None) - == module.__name__ - ): - continue - all_classes.append(getattr(module, name)) - sys.path = original_path return all_classes + all_classes_exported @@ -390,10 +395,12 @@ def load_classes_definition( module_files.append(path_str) for file_path in module_files: + classes_from_module = _get_classes_from_module(file_path) + classes_list.extend(classes_from_module) if verbose: + print('=' * 80) print(file_path) - - classes_list.extend(_get_classes_from_module(file_path)) + print(classes_from_module) return [_get_class_structure(cls, source) for cls in classes_list]