Skip to content

Commit

Permalink
fix _get_classes_from_module
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed Jun 25, 2024
1 parent 3ce60af commit 62dccc7
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions src/umlizer/class_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,25 +329,40 @@ def _get_classes_from_module(module_path: str) -> list[Type[Any]]:
"""
module_path, module_name = _extract_module_name(module_path)
original_path = copy.deepcopy(sys.path)

sys.path.insert(0, module_path)
try:
sys.path.insert(0, module_path)
module = importlib.import_module(module_name)
sys.path = original_path
classes_list = [
getattr(module, o)
for o in dir(module)
if inspect.isclass(getattr(module, o)) and not o.startswith('__')
]
return classes_list
except KeyboardInterrupt:
raise_error('KeyboardInterrupt', 1)
except Exception as e:
print(f' Error loading module {module_name} '.center(80, '='))
short_module_path = '.'.join(module_path.split('/')[-3:])
print(f' Error loading module {short_module_path} '.center(80, '='))
print(e)
print('.' * 80)
sys.path = original_path
return []
return classes_list

# If __all__ is defined, get only the classes listed in __all__
all_classes_exported = []
if hasattr(module, '__all__'):
for name in module.__all__:
if not inspect.isclass(getattr(module, name)):
continue
all_classes_exported.append(getattr(module, name))

# Get all classes defined directly in the module
all_classes = []
for name in dir(module):
if not (
inspect.isclass(getattr(module, name))
and getattr(getattr(module, name), '__module__', None)
== module.__name__
):
continue
all_classes.append(getattr(module, name))
sys.path = original_path
return all_classes + all_classes_exported


def create_diagram(
Expand Down Expand Up @@ -408,7 +423,12 @@ def load_classes_definition(
raise_error(f'Path "{path_str}" doesn\'t exist.', 1)
if os.path.isdir(path_str):
sys.path.insert(0, path_str)
exclude_pattern = [exclude.strip() for exclude in exclude.split(',')]
if exclude:
exclude_pattern = [
exclude.strip() for exclude in exclude.split(',')
]
else:
exclude_pattern = []
exclude_pattern.append('__pycache__')
module_files.extend(
_search_modules(path_str, exclude_pattern=exclude_pattern)
Expand Down

0 comments on commit 62dccc7

Please sign in to comment.