Skip to content

Commit

Permalink
fix issues with identically module names
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed Jun 27, 2024
1 parent daa0b5d commit 0b5ca49
Showing 1 changed file with 36 additions and 29 deletions.
65 changes: 36 additions & 29 deletions src/umlizer/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,55 +288,60 @@ def _extract_module_name(module_path: str) -> tuple[str, str]:
return module_path, module_name


def _get_classes_from_module(module_path: str) -> list[Type[Any]]:
def _get_classes_from_module(module_file_path: str) -> list[Type[Any]]:
"""
Extract classes from a given module path using importlib.import_module.
Extract classes from a given module path using importlib.
Parameters
----------
module_path : str
The path to the module from which classes are to be extracted.
module_file_path : str
The path to the module file from which classes are to be extracted.
Returns
-------
list
A list of class objects.
"""
module_path, module_name = _extract_module_name(module_path)
module_path, module_name = _extract_module_name(module_file_path)
original_path = copy.deepcopy(sys.path)

sys.path.insert(0, module_path)
try:
module = importlib.import_module(module_name)
spec = importlib.util.spec_from_file_location(
module_name, module_file_path
)
if spec is None:
raise ImportError(f'Cannot find spec for {module_file_path}')
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore
sys.path = original_path

all_classes_exported = []

if hasattr(module, '__all__'):
all_classes_exported = [
getattr(module, name)
for name in module.__all__
if inspect.isclass(getattr(module, name))
]

all_classes = [
getattr(module, name)
for name in dir(module)
if inspect.isclass(getattr(module, name))
and getattr(getattr(module, name), '__module__', None)
== module.__name__
]
except KeyboardInterrupt:
raise_error('KeyboardInterrupt', 1)
except Exception as e:
short_module_path = '.'.join(module_path.split('/')[-3:])
short_module_path = '.'.join(module_path.split(os.sep)[-3:])
print(f' Error loading module {short_module_path} '.center(80, '='))
print(e)
print('.' * 80)
sys.path = original_path
return []

# If __all__ is defined, get only the classes listed in __all__
all_classes_exported = []
if hasattr(module, '__all__'):
for name in module.__all__:
if not inspect.isclass(getattr(module, name)):
continue
all_classes_exported.append(getattr(module, name))

# Get all classes defined directly in the module
all_classes = []
for name in dir(module):
if not (
inspect.isclass(getattr(module, name))
and getattr(getattr(module, name), '__module__', None)
== module.__name__
):
continue
all_classes.append(getattr(module, name))
sys.path = original_path
return all_classes + all_classes_exported


Expand Down Expand Up @@ -390,10 +395,12 @@ def load_classes_definition(
module_files.append(path_str)

for file_path in module_files:
classes_from_module = _get_classes_from_module(file_path)
classes_list.extend(classes_from_module)
if verbose:
print('=' * 80)
print(file_path)

classes_list.extend(_get_classes_from_module(file_path))
print(classes_from_module)

return [_get_class_structure(cls, source) for cls in classes_list]

Expand Down

0 comments on commit 0b5ca49

Please sign in to comment.