Skip to content

Commit

Permalink
Handle backlink collisions for SQLAlchemy model generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
vpetrovykh committed Dec 4, 2024
1 parent 2ba52c8 commit 61c9c75
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 11 deletions.
28 changes: 28 additions & 0 deletions gel/orm/introspection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import re
import collections


INTRO_QUERY = '''
Expand Down Expand Up @@ -100,6 +101,7 @@ def _process_links(types, modules):
for spec in types:
check_name(spec['name'])
type_map[spec['name']] = spec
spec['backlink_renames'] = {}

for prop in spec['properties']:
check_name(prop['name'])
Expand Down Expand Up @@ -197,6 +199,32 @@ def _process_links(types, modules):
'target': target,
})

# Go over backlinks and resolve any name collisions using the type map.
for spec in types:
mod = spec["name"].rsplit('::', 1)[0]
sql_source = get_sql_name(spec["name"])

# Find collisions in backlink names
bk = collections.defaultdict(list)
for link in spec['backlinks']:
if link['name'].startswith('backlink_via_'):
bk[link['name']].append(link)

for bklinks in bk.values():
if len(bklinks) > 1:
# We have a collision, so each backlink in it must now be
# disambiguated.
for link in bklinks:
origsrc = get_sql_name(link['target']['name'])
lname = link['name']
link['name'] = f'{lname}_from_{origsrc}'
# Also update the original source of the link with the
# special backlink name.
source = type_map[link['target']['name']]
fwname = lname.replace('backlink_via_', '', 1)
link['fwname'] = fwname
source['backlink_renames'][fwname] = link['name']

return {
'modules': modules,
'object_types': types,
Expand Down
25 changes: 15 additions & 10 deletions gel/orm/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,10 @@ def render_models(self, spec):
modules[mod]['prop_objects'][pobj['name']] = pobj

for rec in spec['object_types']:
mod, _ = get_mod_and_name(rec['name'])
mod, name = get_mod_and_name(rec['name'])
if 'object_types' not in modules[mod]:
modules[mod]['object_types'] = []
modules[mod]['object_types'].append(rec)
modules[mod]['object_types'] = {}
modules[mod]['object_types'][name] = rec

# Initialize the base directory
self.init_dir(self.outdir)
Expand All @@ -215,13 +215,13 @@ def render_models(self, spec):

for lobj in maps.get('link_objects', {}).values():
self.write()
self.render_link_object(lobj)
self.render_link_object(lobj, modules)

for pobj in maps.get('prop_objects', {}).values():
self.write()
self.render_prop_object(pobj)

for rec in maps.get('object_types', []):
for rec in maps.get('object_types', {}).values():
self.write()
self.render_type(rec, modules)

Expand All @@ -243,11 +243,11 @@ def render_link_table(self, spec):
self.dedent()
self.write(f')')

def render_link_object(self, spec):
def render_link_object(self, spec, modules):
mod = spec['module']
name = spec['name']
sql_name = spec['table']
source_link = sql_name.split('.')[-1]
source_name, source_link = sql_name.split('.')

self.write()
self.write(f'class {name}(Base):')
Expand Down Expand Up @@ -281,7 +281,11 @@ def render_link_object(self, spec):
if lname == 'source':
bklink = source_link
else:
bklink = f'backlink_via_{source_link}'
src = modules[mod]['object_types'][source_name]
bklink = src['backlink_renames'].get(
source_link,
f'backlink_via_{source_link}',
)

self.write(
f'{lname}: Mapped[{pyname}] = '
Expand Down Expand Up @@ -446,8 +450,9 @@ def render_link(self, spec, mod, parent, modules):
name = spec['name']
nullable = not spec['required']
tmod, target = get_mod_and_name(spec['target']['name'])
source = modules[mod]['object_types'][parent]
cardinality = spec['cardinality']
bklink = f'backlink_via_{name}'
bklink = source['backlink_renames'].get(name, f'backlink_via_{name}')

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
Expand Down Expand Up @@ -514,7 +519,7 @@ def render_backlink(self, spec, mod, modules):
tmod, target = get_mod_and_name(spec['target']['name'])
cardinality = spec['cardinality']
exclusive = spec['exclusive']
bklink = name.replace('backlink_via_', '', 1)
bklink = spec.get('fwname', name.replace('backlink_via_', '', 1))

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
Expand Down
20 changes: 20 additions & 0 deletions tests/dbsetup/features.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,24 @@ insert Theme {
branch := (
select other::Branch{@note := 'fall'} filter .val = 'big'
)
};

insert Foo {
name := 'foo'
};

insert Foo {
name := 'oof'
};

insert Bar {
n := 123,
foo := assert_single((select Foo filter .name = 'foo')),
many_foo := Foo,
};

insert Who {
x := 456,
foo := assert_single((select Foo filter .name = 'oof')),
many_foo := (select Foo{@note := 'just one'} filter .name = 'foo'),
};
20 changes: 19 additions & 1 deletion tests/dbsetup/features_default.esdl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,22 @@ type Theme {
link branch: other::Branch {
property note: str;
}
};
};

type Foo {
required property name: str;
};

type Bar {
link foo: Foo;
multi link many_foo: Foo;
required property n: int64;
};

type Who {
link foo: Foo;
multi link many_foo: Foo {
property note: str;
};
required property x: int64;
};
29 changes: 29 additions & 0 deletions tests/test_sqla_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,32 @@ def test_sqla_module_02(self):
('orange', 'swapped', 'small'),
},
)

def test_sqla_bklink_01(self):
# test backlink name collisions
foo = self.sess.query(self.sm.Foo).filter_by(name='foo').one()
oof = self.sess.query(self.sm.Foo).filter_by(name='oof').one()

# only one link from Bar 123 to foo
self.assertEqual(
[obj.n for obj in foo.backlink_via_foo_from_Bar],
[123],
)
# only one link from Who 456 to oof
self.assertEqual(
[obj.x for obj in oof.backlink_via_foo_from_Who],
[456],
)

# foo is linked via `many_foo` from both Bar and Who
self.assertEqual(
[obj.n for obj in foo.backlink_via_many_foo_from_Bar],
[123],
)
self.assertEqual(
[
(obj.note, obj.source.x)
for obj in foo.backlink_via_many_foo_from_Who
],
[('just one', 456)],
)

0 comments on commit 61c9c75

Please sign in to comment.