Skip to content

Commit

Permalink
Merge pull request #145 from ReactionMechanismGenerator/flux_d_fixes
Browse files Browse the repository at this point in the history
Fix species identification in reaction wells when generating flux diagrams
  • Loading branch information
JintaoWu98 authored Jan 25, 2024
2 parents 408e5e3 + 8ac4929 commit 868e2de
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 36 deletions.
132 changes: 101 additions & 31 deletions t3/utils/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ def generate_flux(model_path: str,
explore_tol: float = 0.95,
dead_end_tol: float = 0.10,
generate_separate_diagrams_per_observable: bool = False,
display_flux_ratio: bool = True,
report_flux_ratio: bool = True,
report_actual_flux: bool = False,
display_concentrations: bool = True,
display_r_n_p: bool = True,
scaling: Optional[float] = None,
fix_cantera_model: bool = True,
allowed_nodes: Optional[List[str]] = None,
):
"""
Generate a flux diagram for a given model and composition.
Expand Down Expand Up @@ -61,11 +63,14 @@ def generate_flux(model_path: str,
Don't explore further consumption is lower than this tolerance
times the net rate of production.
generate_separate_diagrams_per_observable (bool, optional): Whether to generate a separate flux diagram for each observable.
display_flux_ratio (bool, optional): Whether to display the flux ratio.
report_flux_ratio (bool, optional): Whether to display the flux ratio.
report_actual_flux (bool, optional): Whether to report the actual flux values rather than the relative flux.
display_concentrations (bool, optional): Whether to display the concentrations.
display_r_n_p (bool, optional): Whether to display the other reactants and products on each arrow.
scaling (Optional[float], optional): The scaling of the final image, 100 means no scaling.
fix_cantera_model (bool, optional): Whether to fix the Cantera model before running the simulation.
allowed_nodes (Optional[List[str]], optional): A list of nodes to consider.
any node outside this list will not appear in the flux diagram.
Structures:
profiles: {<time in s>: {'P': <pressure in bar>,
Expand Down Expand Up @@ -105,9 +110,11 @@ def generate_flux(model_path: str,
explore_tol=explore_tol,
dead_end_tol=dead_end_tol,
display_concentrations=display_concentrations,
display_flux_ratio=display_flux_ratio,
report_flux_ratio=report_flux_ratio,
report_actual_flux=report_actual_flux,
display_r_n_p=display_r_n_p,
scaling=scaling,
allowed_nodes=allowed_nodes,
)
else:
generate_flux_diagrams(profiles=profiles,
Expand All @@ -116,9 +123,11 @@ def generate_flux(model_path: str,
explore_tol=explore_tol,
dead_end_tol=dead_end_tol,
display_concentrations=display_concentrations,
display_flux_ratio=display_flux_ratio,
report_flux_ratio=report_flux_ratio,
report_actual_flux=report_actual_flux,
display_r_n_p=display_r_n_p,
scaling=scaling,
allowed_nodes=allowed_nodes,
)


Expand Down Expand Up @@ -479,9 +488,11 @@ def generate_flux_diagrams(profiles: dict,
explore_tol: float = 0.95,
dead_end_tol: float = 0.10,
display_concentrations: bool = True,
display_flux_ratio: bool = True,
report_flux_ratio: bool = True,
report_actual_flux: bool = False,
display_r_n_p: bool = True,
scaling: Optional[float] = None,
allowed_nodes: Optional[List[str]] = None,
):
"""
Generate flux diagrams.
Expand All @@ -495,9 +506,12 @@ def generate_flux_diagrams(profiles: dict,
Don't explore further consumption is lower than this tolerance
times the net rate of production.
display_concentrations (bool, optional): Whether to display the concentrations.
display_flux_ratio (bool, optional): Whether to display the flux ratio.
report_flux_ratio (bool, optional): Whether to display the flux ratio.
report_actual_flux (bool, optional): Whether to report the actual flux values rather than the relative flux.
display_r_n_p (bool, optional): Whether to display the other reactants and products on each arrow.
scaling (Optional[float], optional): The scaling of the final image.
allowed_nodes (Optional[List[str]], optional): A list of nodes to consider.
any node outside this list will not appear in the flux diagram.
Structures:
graph: {<species1>: {'rxn1': [[<the species formed>], <rop_value>],
Expand All @@ -522,9 +536,11 @@ def generate_flux_diagrams(profiles: dict,
max_rop=max_rop,
folder_path=folder_path,
display_concentrations=display_concentrations,
display_flux_ratio=display_flux_ratio,
report_flux_ratio=report_flux_ratio,
report_actual_flux=report_actual_flux,
display_r_n_p=display_r_n_p,
scaling=scaling,
allowed_nodes=allowed_nodes,
)


Expand All @@ -537,9 +553,11 @@ def create_digraph(flux_graph: dict,
max_rop: float,
folder_path: str,
display_concentrations: bool = True,
display_flux_ratio: bool = True,
report_flux_ratio: bool = True,
report_actual_flux: bool = False,
display_r_n_p: bool = True,
scaling: Optional[float] = None,
allowed_nodes: Optional[List[str]] = None,
) -> None:
"""
Create a directed graph from the flux graph and save it as a .dot file.
Expand All @@ -554,9 +572,12 @@ def create_digraph(flux_graph: dict,
max_rop (float): The absolute maximal ROP value.
folder_path (str): The path to the folder in which to save the flux diagrams and accompanied data.
display_concentrations (bool, optional): Whether to display the concentrations.
display_flux_ratio (bool, optional): Whether to display the flux ratio.
report_flux_ratio (bool, optional): Whether to display the flux ratio.
report_actual_flux (bool, optional): Whether to report the actual flux values rather than the relative flux.
display_r_n_p (bool, optional): Whether to display the other reactants and products on each arrow.
scaling (Optional[float], optional): The scaling of the final image.
allowed_nodes (Optional[List[str]], optional): A list of nodes to consider.
any node outside this list will not appear in the flux diagram.
"""
if not os.path.isdir(folder_path):
os.makedirs(folder_path)
Expand All @@ -568,6 +589,10 @@ def create_digraph(flux_graph: dict,
for rop_list in rxn_dict.values():
species_to_consider.update(rop_list[0])
xs = [v for k, v in profile['X'].items() if k in species_to_consider]
if not len(xs):
print(f'Could not create a flux diagram for observables {observables} at {time} s. '
f'Could not simulate the system.')
return
x_max, x_min = max(xs), min(xs)
abs_rops = [abs(values[1]) for inner_dict in flux_graph.values() for values in inner_dict.values()]
rop_min, rop_max = min(abs_rops), max(abs_rops)
Expand All @@ -590,7 +615,8 @@ def create_digraph(flux_graph: dict,
for downstream_node_label in downstream_node_labels:
if downstream_node_label in nodes_to_explore and downstream_node_label not in visited:
visited.add(downstream_node_label)
stack.append(downstream_node_label)
if allowed_nodes is None or downstream_node_label in allowed_nodes:
stack.append(downstream_node_label)
downstream_nodes = [get_node(graph=graph,
label=downstream_node_label,
nodes=nodes,
Expand All @@ -607,24 +633,30 @@ def create_digraph(flux_graph: dict,
downstream_nodes=downstream_nodes,
downstream_node_labels=downstream_node_labels,
width=get_width(x=rop, x_min=rop_min, x_max=rop_max),
rel_rop=abs(rop) / abs(rop_max),
rop=rop,
max_rop=max_rop,
rxn=rxn_string,
multipliers=multipliers,
display_flux_ratio=display_flux_ratio,
report_flux_ratio=report_flux_ratio,
report_actual_flux=report_actual_flux,
display_r_n_p=display_r_n_p,
allowed_nodes=allowed_nodes,
)
graph.set(name='label', value=f'Flux diagram at {time} s, ROP range: [{min_rop:.2e}, {max_rop:.2e}] ' +
'mol/cm\N{SUPERSCRIPT THREE}/s)')
if scaling is not None:
graph.set('size', f'{scaling},{scaling}')
if allowed_nodes is not None:
for node in graph.get_nodes():
if node.get_name() not in allowed_nodes:
graph.del_node(node)
graph_dot_path = os.path.join(folder_path, f'flux_diagram_{time}_s.dot')
graph_png_path = os.path.join(folder_path, f'flux_diagram_{time}_s.png')
graph.write(graph_dot_path)
try:
graph.write_png(graph_png_path)
except AssertionError:
print(f'Could not create a flux diagram for observables {observables} at {time} s.')
# todo: add smiles in transparent, and add the species label


def add_edges(graph: pydot.Dot,
Expand All @@ -633,11 +665,14 @@ def add_edges(graph: pydot.Dot,
downstream_nodes: List[pydot.Node],
downstream_node_labels: List[str],
width: float,
rel_rop: float,
rop: float,
max_rop: float,
rxn: Optional[str] = None,
multipliers: Optional[List[float]] = None,
display_flux_ratio: bool = True,
report_flux_ratio: bool = True,
report_actual_flux: bool = False,
display_r_n_p: bool = True,
allowed_nodes: Optional[List[str]] = None,
):
"""
Add edges to the graph.
Expand All @@ -650,24 +685,34 @@ def add_edges(graph: pydot.Dot,
downstream_node_labels (List[str]): The downstream node labels.
rxn (str): The reaction string.
width (float): The edge width.
rop (float): The normalized ROP value.
max_rop (float): The maximal ROP value.
multipliers (List[float]): The stoichiometric multipliers.
display_flux_ratio (bool, optional): Whether to display the flux ratio.
report_flux_ratio (bool, optional): Whether to display the flux ratio.
report_actual_flux (bool, optional): Whether to report the actual flux values rather than the relative flux.
display_r_n_p (bool, optional): Whether to display the other reactants and products on each arrow.
allowed_nodes (Optional[List[str]], optional): A list of nodes to consider.
any node outside this list will not appear in the flux diagram.
"""
for multiplier, node, node_label in zip(multipliers, downstream_nodes, downstream_node_labels):
rs, ps = get_other_reactants_and_products(rxn=rxn, spcs=[origin_label, node_label])
edge = pydot.Edge(origin_node, node, penwidth=width + np.log10(multiplier), fontsize=8)
label = ''
if display_flux_ratio:
label = f'{rel_rop:.1f}' if rel_rop > 0.1 else f'{rel_rop:.1e}'
if display_r_n_p and rs:
label += f'\n{rs}'
if display_r_n_p and ps:
label += f'\n{ps}'
if label != '':
edge.set('label', label)
edge.set('arrowhead', 'vee')
graph.add_edge(edge)
if allowed_nodes is None or (origin_label in allowed_nodes and node_label in allowed_nodes):
rs, ps = get_other_reactants_and_products(rxn=rxn, spcs=[origin_label, node_label])
edge = pydot.Edge(origin_node, node, penwidth=width + np.log10(multiplier), fontsize=8)
label = ''
rop = abs(rop)
if report_flux_ratio:
label = f'{rop:.1f}' if rop > 0.1 else f'{rop:.1e}'
elif report_actual_flux:
actual_rop = rop * max_rop
label = f'{actual_rop:.1f}' if actual_rop > 0.1 else f'{actual_rop:.1e}'
if display_r_n_p and rs:
label += f'\n{rs}'
if display_r_n_p and ps:
label += f'\n{ps}'
if label != '':
edge.set('label', label)
edge.set('arrowhead', 'vee')
graph.add_edge(edge)


def get_width(x: float,
Expand All @@ -690,6 +735,8 @@ def get_width(x: float,
max_width, min_width = 4, 0.2
x, x_min, x_max = abs(x), abs(x_min), abs(x_max)
if not log_scale:
if x == x_min == x_max:
return 1
return min_width + (x - x_min) * (max_width - min_width) / (x_max - x_min)
return get_width(x=-np.log10(x_min) - np.log10(x_max) + np.log10(x),
x_min=-np.log10(x_max),
Expand Down Expand Up @@ -821,7 +868,7 @@ def get_flux_graph(profile: dict,
graph[node][rxn][1] += rop
else:
graph[node][rxn] = [opposite_rxn_species, rop]
min_rop = min_rop * max_rop
min_rop = min_rop * max_rop if min_rop is not None else 0
return graph, nodes_to_explore, min_rop, max_rop


Expand Down Expand Up @@ -888,14 +935,37 @@ def get_opposite_rxn_species(rxn: str, spc: str) -> List[str]:
"""
arrow = ' <=> ' if ' <=> ' in rxn else ' => '
wells = rxn.split(arrow)
counts = wells[0].count(spc), wells[1].count(spc)
counts = count_species_in_well(well=wells[0], spc=spc), count_species_in_well(well=wells[1], spc=spc)
i = int(counts[0] > counts[1])
species = wells[i].split(' + ')
for token in [' + M', ' (+M)', 'M', ' + ']:
species = [s.replace(token, '') for s in species]
return [s for s in species if s != '']


def count_species_in_well(well: str,
spc: str,
) -> int:
"""
Count the number of times a species appears in a well.
Args:
well (str): The well string.
spc (str): The species label.
Returns:
int: The number of times a species appears in the well.
"""
count = 0
for token in [' + M', ' (+M)', 'M']:
well = well.replace(token, '')
splits = well.split(' + ')
for s in splits:
if s == spc:
count += 1
return count


def unpack_stoichiometry(labels: List[str]) -> Tuple[List[str], List[int]]:
"""
Unpack stoichiometry.
Expand Down
26 changes: 21 additions & 5 deletions tests/test_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_generate_flux():
explore_tol=0.95,
dead_end_tol=0.10,
generate_separate_diagrams_per_observable=False,
display_flux_ratio=True,
report_flux_ratio=True,
display_concentrations=True,
display_r_n_p=True,
fix_cantera_model=False,
Expand All @@ -54,7 +54,7 @@ def test_generate_flux():
explore_tol=0.95,
dead_end_tol=0.10,
generate_separate_diagrams_per_observable=True,
display_flux_ratio=True,
report_flux_ratio=True,
display_concentrations=True,
display_r_n_p=True,
fix_cantera_model=False,
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_create_digraph_NH3():
max_rop=max_rop,
folder_path=folder_path,
display_concentrations=False,
display_flux_ratio=False,
report_flux_ratio=False,
display_r_n_p=True,
)
assert os.path.isfile(os.path.join(folder_path, 'flux_diagram_0.01_s.png'))
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_create_digraph_N2H4():
max_rop=max_rop,
folder_path=folder_path,
display_concentrations=True,
display_flux_ratio=True,
report_flux_ratio=True,
display_r_n_p=True,
)
assert os.path.isfile(os.path.join(folder_path, 'flux_diagram_0.005_s.png'))
Expand Down Expand Up @@ -361,7 +361,8 @@ def test_get_flux_graph():
assert len(profile['ROPs']) == 36
flux_graph, nodes_to_explore, min_rop, max_rop = flux.get_flux_graph(profile=profile, observables=['H4N2(1)'])
if i == 0:
assert nodes_to_explore == {'H2(4)', 'H3N2(6)', 'ammonia(9)', 'H(3)', 'H2N2(7)', '2 NH2(5)', 'HN2(10)', 'N2(2)', 'NH2(5)'}
assert nodes_to_explore == {'H2(4)', '2 H3N2(6)', 'ammonia(9)', 'H(3)', 'H2N2(7)', '2 NH2(5)',
'HN2(10)', 'N2(2)', 'NH2(5)', 'H3N2(6)'}
assert almost_equal(min_rop, 3.27e-21, ratio=100)
assert almost_equal(max_rop, 20.3659, places=3)
assert list(flux_graph.keys()) == ['H4N2(1)', 'NH2(5)', 'HN2(10)', 'H(3)', 'H2N2(7)', 'H2(4)', 'ammonia(9)', 'H3N2(6)']
Expand Down Expand Up @@ -391,6 +392,21 @@ def test_get_opposite_rxn_species():
assert flux.get_opposite_rxn_species(rxn=rxn, spc='H(3)') == ['H2O2(18)']
assert flux.get_opposite_rxn_species(rxn=rxn, spc='H2O2(18)') == ['H(3)', 'HO2(6)']

rxn = 'HO2 + NO <=> NO2 + OH'
assert flux.get_opposite_rxn_species(rxn=rxn, spc='HO2') == ['NO2', 'OH']
assert flux.get_opposite_rxn_species(rxn=rxn, spc='NO') == ['NO2', 'OH']
assert flux.get_opposite_rxn_species(rxn=rxn, spc='NO2') == ['HO2', 'NO']
assert flux.get_opposite_rxn_species(rxn=rxn, spc='OH') == ['HO2', 'NO']


def test_count_species_in_well():
"""Test counting a species in a well"""
assert flux.count_species_in_well(well='H2O + H + H', spc='H2O') == 1
assert flux.count_species_in_well(well='H2O + H + H', spc='H') == 2
assert flux.count_species_in_well(well='H2O + H + H', spc='H2') == 0
assert flux.count_species_in_well(well='HO2 + NO', spc='NO') == 1
assert flux.count_species_in_well(well='NO2 + OH', spc='NO') == 0


def test_get_other_reactants_and_products():
"""Test getting the reactants and products other than a given species."""
Expand Down

0 comments on commit 868e2de

Please sign in to comment.