Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix label handling during flatten #1208

Merged
merged 8 commits into from
Dec 6, 2024
3 changes: 3 additions & 0 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,6 +1917,9 @@ def flatten(self, inplace=True):
for neighbor in nx.neighbors(bond_graph, particle):
new_bonds.append((particle, neighbor))

# Remove all labels which refer to children in the hierarchy
self.labels.clear()

# Remove all the children
if inplace:
for child in children_list:
Expand Down
61 changes: 30 additions & 31 deletions mbuild/tests/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,42 +801,14 @@ def test_remove(self, ethane):

# Test to reset labels after hydrogens
ethane6 = mb.clone(ethane)
ethane6.flatten()
hydrogens = ethane6.particles_by_name("H")
ethane6.remove(hydrogens)
ethane6.remove(hydrogens, reset_labels=True)
assert list(ethane6.labels.keys()) == [
"methyl1",
"methyl2",
"C",
"C[0]",
"H",
"C[1]",
"port",
"port[1]",
"port[3]",
"port[5]",
"port[7]",
"port[9]",
"port[11]",
]

ethane7 = mb.clone(ethane)
ethane7.flatten()
hydrogens = ethane7.particles_by_name("H")
ethane7.remove(hydrogens, reset_labels=True)

assert list(ethane7.labels.keys()) == [
"C",
"C[0]",
"C[1]",
"port",
"port[0]",
"port[1]",
"port[2]",
"port[3]",
"port[4]",
"port[5]",
]
assert ethane6.available_ports() == []
assert len(ethane6.all_ports()) == 6

def test_remove_many(self, ethane):
ethane.remove([ethane.children[0], ethane.children[1]])
Expand Down Expand Up @@ -1065,6 +1037,33 @@ def test_flatten_box_of_eth(self, ethane):
box_of_eth.flatten()
assert len(box_of_eth.children) == box_of_eth.n_particles == 8 * 2
assert box_of_eth.n_bonds == 7 * 2
assert list(box_of_eth.labels.keys()) == [
"C",
"C[0]",
"H",
"H[0]",
"H[1]",
"H[2]",
"C[1]",
"H[3]",
"H[4]",
"H[5]",
"C[2]",
"H[6]",
"H[7]",
"H[8]",
"C[3]",
"H[9]",
"H[10]",
"H[11]",
]

def test_flatten_then_fill_box(self, benzene):
benzene.flatten(inplace=True)
benzene_box = mb.packing.fill_box(
compound=benzene, n_compounds=2, density=0.3
)
assert next(iter(benzene_box.particles())).root.bond_graph

def test_flatten_with_port(self, ethane):
ethane.remove(ethane[2])
Expand Down
Loading