diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py index 009d84a..29ea142 100644 --- a/ariadne_codegen/contrib/client_forward_refs.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -42,7 +42,7 @@ def __init__(self, schema: GraphQLSchema, config_dict: Dict) -> None: # Imported classes are classes imported from local imports. We keep a # map between name and module so we know how to import them in each # method. - self.imported_classes: Dict[str, str] = {} + self.imported_classes: Dict[str, tuple[int, str]] = {} # Imported classes in each method definition. self.imported_in_method: Set[str] = set() @@ -116,9 +116,8 @@ def _store_imported_classes(self, module_body: List[ast.stmt]): continue for name in node.names: - from_ = "." * node.level + node.module if isinstance(name, ast.alias): - self.imported_classes[name.name] = from_ + self.imported_classes[name.name] = (node.level, node.module) def _rewrite_input_args_to_constants( self, method_def: Union[ast.FunctionDef, ast.AsyncFunctionDef] @@ -178,12 +177,14 @@ def _insert_import_statement_in_method( # We add the class to our set of imported in methods - these classes # don't need to be imported at all in the global scope. self.imported_in_method.add(import_class_name) + + level, module_name = self.imported_classes[import_class_name] method_def.body.insert( 0, ast.ImportFrom( - module=self.imported_classes[import_class_name], + module=module_name, names=[import_class], - level=1, + level=level, ), ) @@ -342,10 +343,12 @@ def _add_forward_ref_imports( """ type_checking_imports = {} for cls in self.input_and_return_types: - module_name = self.imported_classes[cls] + level, module_name = self.imported_classes[cls] if module_name not in type_checking_imports: type_checking_imports[module_name] = ast.ImportFrom( - module=module_name, names=[], level=1 + module=module_name, + names=[], + level=level, ) type_checking_imports[module_name].names.append(ast.alias(cls)) @@ -364,7 +367,7 @@ def _add_forward_ref_imports( ast.ImportFrom( module=TYPE_CHECKING_MODULE, names=[ast.alias(TYPE_CHECKING_FLAG)], - level=1, + level=0, ), ) diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 3cf9f1d..91e193e 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -213,6 +213,36 @@ def test_main_shows_version(): "example_client", CLIENTS_PATH / "custom_sync_query_builder" / "expected_client", ), + ( + ( + CLIENTS_PATH / "client_forward_refs" / "pyproject.toml", + ( + CLIENTS_PATH / "client_forward_refs" / "queries.graphql", + CLIENTS_PATH / "client_forward_refs" / "schema.graphql", + CLIENTS_PATH / "client_forward_refs" / "custom_scalars.py", + ), + ), + "client_forward_refs", + CLIENTS_PATH / "client_forward_refs" / "expected_client", + ), + ( + ( + CLIENTS_PATH / "client_forward_refs_shorter_results" / "pyproject.toml", + ( + CLIENTS_PATH + / "client_forward_refs_shorter_results" + / "queries.graphql", + CLIENTS_PATH + / "client_forward_refs_shorter_results" + / "schema.graphql", + CLIENTS_PATH + / "client_forward_refs_shorter_results" + / "custom_scalars.py", + ), + ), + "client_forward_refs_shorter_results", + CLIENTS_PATH / "client_forward_refs_shorter_results" / "expected_client", + ), ], indirect=["project_dir"], )