From 58753b7ec1ee948ca6416814cb3512bc83432d7f Mon Sep 17 00:00:00 2001 From: Victor Petrovykh Date: Wed, 27 Nov 2024 13:36:06 -0500 Subject: [PATCH] Handle backlink collisions for SQLAlchemy model generator. --- gel/orm/introspection.py | 28 ++++++++++++++++++++++++++++ gel/orm/sqla.py | 25 +++++++++++++++---------- tests/dbsetup/features.edgeql | 20 ++++++++++++++++++++ tests/dbsetup/features_default.esdl | 20 +++++++++++++++++++- tests/test_sqla_features.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 111 insertions(+), 11 deletions(-) diff --git a/gel/orm/introspection.py b/gel/orm/introspection.py index 3d964de7..f19bffd7 100644 --- a/gel/orm/introspection.py +++ b/gel/orm/introspection.py @@ -1,5 +1,6 @@ import json import re +import collections INTRO_QUERY = ''' @@ -100,6 +101,7 @@ def _process_links(types, modules): for spec in types: check_name(spec['name']) type_map[spec['name']] = spec + spec['backlink_renames'] = {} for prop in spec['properties']: check_name(prop['name']) @@ -197,6 +199,32 @@ def _process_links(types, modules): 'target': target, }) + # Go over backlinks and resolve any name collisions using the type map. + for spec in types: + mod = spec["name"].rsplit('::', 1)[0] + sql_source = get_sql_name(spec["name"]) + + # Find collisions in backlink names + bk = collections.defaultdict(list) + for link in spec['backlinks']: + if link['name'].startswith('backlink_via_'): + bk[link['name']].append(link) + + for bklinks in bk.values(): + if len(bklinks) > 1: + # We have a collision, so each backlink in it must now be + # disambiguated. + for link in bklinks: + origsrc = get_sql_name(link['target']['name']) + lname = link['name'] + link['name'] = f'{lname}_from_{origsrc}' + # Also update the original source of the link with the + # special backlink name. + source = type_map[link['target']['name']] + fwname = lname.replace('backlink_via_', '', 1) + link['fwname'] = fwname + source['backlink_renames'][fwname] = link['name'] + return { 'modules': modules, 'object_types': types, diff --git a/gel/orm/sqla.py b/gel/orm/sqla.py index bbff6f41..bf18f264 100644 --- a/gel/orm/sqla.py +++ b/gel/orm/sqla.py @@ -189,10 +189,10 @@ def render_models(self, spec): modules[mod]['prop_objects'][pobj['name']] = pobj for rec in spec['object_types']: - mod, _ = get_mod_and_name(rec['name']) + mod, name = get_mod_and_name(rec['name']) if 'object_types' not in modules[mod]: - modules[mod]['object_types'] = [] - modules[mod]['object_types'].append(rec) + modules[mod]['object_types'] = {} + modules[mod]['object_types'][name] = rec # Initialize the base directory self.init_dir(self.outdir) @@ -215,13 +215,13 @@ def render_models(self, spec): for lobj in maps.get('link_objects', {}).values(): self.write() - self.render_link_object(lobj) + self.render_link_object(lobj, modules) for pobj in maps.get('prop_objects', {}).values(): self.write() self.render_prop_object(pobj) - for rec in maps.get('object_types', []): + for rec in maps.get('object_types', {}).values(): self.write() self.render_type(rec, modules) @@ -243,11 +243,11 @@ def render_link_table(self, spec): self.dedent() self.write(f')') - def render_link_object(self, spec): + def render_link_object(self, spec, modules): mod = spec['module'] name = spec['name'] sql_name = spec['table'] - source_link = sql_name.split('.')[-1] + source_name, source_link = sql_name.split('.') self.write() self.write(f'class {name}(Base):') @@ -281,7 +281,11 @@ def render_link_object(self, spec): if lname == 'source': bklink = source_link else: - bklink = f'backlink_via_{source_link}' + src = modules[mod]['object_types'][source_name] + bklink = src['backlink_renames'].get( + source_link, + f'backlink_via_{source_link}', + ) self.write( f'{lname}: Mapped[{pyname}] = ' @@ -446,8 +450,9 @@ def render_link(self, spec, mod, parent, modules): name = spec['name'] nullable = not spec['required'] tmod, target = get_mod_and_name(spec['target']['name']) + source = modules[mod]['object_types'][parent] cardinality = spec['cardinality'] - bklink = f'backlink_via_{name}' + bklink = source['backlink_renames'].get(name, f'backlink_via_{name}') if spec.get('has_link_object'): # intermediate object will have the actual source and target @@ -514,7 +519,7 @@ def render_backlink(self, spec, mod, modules): tmod, target = get_mod_and_name(spec['target']['name']) cardinality = spec['cardinality'] exclusive = spec['exclusive'] - bklink = name.replace('backlink_via_', '', 1) + bklink = spec.get('fwname', name.replace('backlink_via_', '', 1)) if spec.get('has_link_object'): # intermediate object will have the actual source and target diff --git a/tests/dbsetup/features.edgeql b/tests/dbsetup/features.edgeql index 55ce8cf7..7e89aa6e 100644 --- a/tests/dbsetup/features.edgeql +++ b/tests/dbsetup/features.edgeql @@ -59,4 +59,24 @@ insert Theme { branch := ( select other::Branch{@note := 'fall'} filter .val = 'big' ) +}; + +insert Foo { + name := 'foo' +}; + +insert Foo { + name := 'oof' +}; + +insert Bar { + n := 123, + foo := assert_single((select Foo filter .name = 'foo')), + many_foo := Foo, +}; + +insert Who { + x := 456, + foo := assert_single((select Foo filter .name = 'oof')), + many_foo := (select Foo{@note := 'just one'} filter .name = 'foo'), }; \ No newline at end of file diff --git a/tests/dbsetup/features_default.esdl b/tests/dbsetup/features_default.esdl index b8ffec59..59b4bd6d 100644 --- a/tests/dbsetup/features_default.esdl +++ b/tests/dbsetup/features_default.esdl @@ -26,4 +26,22 @@ type Theme { link branch: other::Branch { property note: str; } -}; \ No newline at end of file +}; + +type Foo { + required property name: str; +}; + +type Bar { + link foo: Foo; + multi link many_foo: Foo; + required property n: int64; +}; + +type Who { + link foo: Foo; + multi link many_foo: Foo { + property note: str; + }; + required property x: int64; +}; diff --git a/tests/test_sqla_features.py b/tests/test_sqla_features.py index 45d3457f..4c1e42e6 100644 --- a/tests/test_sqla_features.py +++ b/tests/test_sqla_features.py @@ -287,3 +287,32 @@ def test_sqla_module_02(self): ('orange', 'swapped', 'small'), }, ) + + def test_sqla_bklink_01(self): + # test backlink name collisions + foo = self.sess.query(self.sm.Foo).filter_by(name='foo').one() + oof = self.sess.query(self.sm.Foo).filter_by(name='oof').one() + + # only one link from Bar 123 to foo + self.assertEqual( + [obj.n for obj in foo.backlink_via_foo_from_Bar], + [123], + ) + # only one link from Who 456 to oof + self.assertEqual( + [obj.x for obj in oof.backlink_via_foo_from_Who], + [456], + ) + + # foo is linked via `many_foo` from both Bar and Who + self.assertEqual( + [obj.n for obj in foo.backlink_via_many_foo_from_Bar], + [123], + ) + self.assertEqual( + [ + (obj.note, obj.source.x) + for obj in foo.backlink_via_many_foo_from_Who + ], + [('just one', 456)], + )