From 62dccc772e8aa95d6131c2c3d2035e16c0d3325f Mon Sep 17 00:00:00 2001 From: Ivan Ogasawara Date: Tue, 25 Jun 2024 19:58:00 -0400 Subject: [PATCH] fix _get_classes_from_module --- src/umlizer/class_graph.py | 42 ++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/src/umlizer/class_graph.py b/src/umlizer/class_graph.py index 1ecef80..b7a9340 100644 --- a/src/umlizer/class_graph.py +++ b/src/umlizer/class_graph.py @@ -329,25 +329,40 @@ def _get_classes_from_module(module_path: str) -> list[Type[Any]]: """ module_path, module_name = _extract_module_name(module_path) original_path = copy.deepcopy(sys.path) + + sys.path.insert(0, module_path) try: - sys.path.insert(0, module_path) module = importlib.import_module(module_name) - sys.path = original_path - classes_list = [ - getattr(module, o) - for o in dir(module) - if inspect.isclass(getattr(module, o)) and not o.startswith('__') - ] - return classes_list except KeyboardInterrupt: raise_error('KeyboardInterrupt', 1) except Exception as e: - print(f' Error loading module {module_name} '.center(80, '=')) + short_module_path = '.'.join(module_path.split('/')[-3:]) + print(f' Error loading module {short_module_path} '.center(80, '=')) print(e) print('.' * 80) sys.path = original_path return [] - return classes_list + + # If __all__ is defined, get only the classes listed in __all__ + all_classes_exported = [] + if hasattr(module, '__all__'): + for name in module.__all__: + if not inspect.isclass(getattr(module, name)): + continue + all_classes_exported.append(getattr(module, name)) + + # Get all classes defined directly in the module + all_classes = [] + for name in dir(module): + if not ( + inspect.isclass(getattr(module, name)) + and getattr(getattr(module, name), '__module__', None) + == module.__name__ + ): + continue + all_classes.append(getattr(module, name)) + sys.path = original_path + return all_classes + all_classes_exported def create_diagram( @@ -408,7 +423,12 @@ def load_classes_definition( raise_error(f'Path "{path_str}" doesn\'t exist.', 1) if os.path.isdir(path_str): sys.path.insert(0, path_str) - exclude_pattern = [exclude.strip() for exclude in exclude.split(',')] + if exclude: + exclude_pattern = [ + exclude.strip() for exclude in exclude.split(',') + ] + else: + exclude_pattern = [] exclude_pattern.append('__pycache__') module_files.extend( _search_modules(path_str, exclude_pattern=exclude_pattern)