Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix loading compounds from Madx for repeated element names #469

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/test_compounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,34 @@ def test_slicing_preserve_thick_compound_if_unsliced():
]


def test_madloader_compounds_repeated_elements():
mad = Madx(stdout=False)
mad.options.rbarc = False
mad.input(f"""
mb: sbend, l:=1, angle:=0.1, k0:=0.2, k1=0.1;
ss: sequence, l:=2, refer=entry;
mb, at=0;
mb, at=1;
endsequence;
""")
mad.beam()
mad.use(sequence='ss')

line = xt.Line.from_madx_sequence(
sequence=mad.sequence.ss,
deferred_expressions=True,
allow_thick=True,
)

mb_compound_expected = ['mb_entry', 'mb_den', 'mb', 'mb_dex', 'mb_exit']
mb_compound_result = line.get_compound_subsequence('mb')
assert mb_compound_result == mb_compound_expected

mb0_compound_expected = ['mb_entry:0', 'mb_den:0', 'mb:0', 'mb_dex:0', 'mb_exit:0']
mb0_compound_result = line.get_compound_subsequence('mb:0')
assert mb0_compound_result == mb0_compound_expected


@pytest.fixture(scope='function')
def line_with_compounds(temp_context_default_func):
# The fixture `temp_context_default_func` is defined in conftest.py and is
Expand Down
20 changes: 12 additions & 8 deletions xtrack/mad_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def add_to_line(self, line, buffer):
xtel = self.type(**self.attrs, _buffer=buffer)
name = generate_repeated_name(line, self.name)
line.append_element(xtel, name)
return name


class ElementBuilderWithExpr(ElementBuilder):
Expand All @@ -329,7 +330,7 @@ def add_to_line(self, line, buffer):
elref = line.element_refs[name]
for k, p in self.attrs.items():
set_expr(elref, k, p)
return xtel
return name


class CompoundElementBuilder:
Expand Down Expand Up @@ -367,21 +368,24 @@ def add_to_line(self, line, buffer):
[end_marker]
)

new_name_map = {}
for el in component_elements:
el.add_to_line(line, buffer)
added_name = el.add_to_line(line, buffer)
new_name_map[el.name] = added_name

def _get_names(builder_elements):
return [elem.name for elem in builder_elements]
return [new_name_map[elem.name] for elem in builder_elements]

compound = Compound(
core=_get_names(self.core),
aperture=_get_names(self.aperture),
entry_transform=_get_names(self.entry_transform),
exit_transform=_get_names(self.exit_transform),
entry=start_marker.name,
exit_=end_marker.name,
entry=new_name_map[start_marker.name],
exit_=new_name_map[end_marker.name],
)
line.compound_container.define_compound(self.name, compound)
line.compound_container.define_compound(new_name_map[self.name], compound)
return list(new_name_map.values())


class Aperture:
Expand Down Expand Up @@ -746,8 +750,8 @@ def add_elements(
):
out = {} # tbc
for el in elements:
xt_element = el.add_to_line(line, buffer)
out[el.name] = xt_element # tbc
name = el.add_to_line(line, buffer)
out[el.name] = name # tbc
return out # tbc

@property
Expand Down