From 38adc464b35fe8608f170f02c9200a0c232c36c5 Mon Sep 17 00:00:00 2001 From: Henrik Finsberg Date: Sat, 29 Jun 2024 22:29:41 +0200 Subject: [PATCH] Move from setup.py to pyproject.toml and fix a lot of errors --- bench/winslow_python.py | 3 +- demo/gotran2c/demo.py | 4 +- gotran/algorithms/symbolicnewtonsolution.py | 4 +- gotran/codegeneration/algorithmcomponents.py | 43 ++----- .../codegeneration/avoidsympycontractions.py | 5 +- gotran/codegeneration/codecomponent.py | 71 ++++------- gotran/codegeneration/codegenerators.py | 118 ++++-------------- gotran/codegeneration/compilemodule.py | 7 +- gotran/codegeneration/latexcodegenerator.py | 23 +--- gotran/codegeneration/oderepresentation.py | 24 ++-- gotran/codegeneration/solvercomponents.py | 16 +-- gotran/common/options.py | 83 ++++-------- gotran/input/cellml.py | 56 +++------ gotran/input/mathml.py | 36 +----- gotran/model/expressions.py | 10 +- gotran/model/loadmodel.py | 5 +- gotran/model/ode.py | 87 +++++-------- gotran/model/odecomponent.py | 74 ++++------- gotran/model/odeobjects.py | 14 +-- gotran/model/utils.py | 9 +- gotran/scripts/cellml2gotran.py | 4 +- gotran/scripts/gotran2c.py | 13 +- gotran/scripts/gotran2cpp.py | 7 +- gotran/scripts/gotran2cuda.py | 13 +- gotran/scripts/gotran2dolfin.py | 3 +- gotran/scripts/gotran2julia.py | 3 +- gotran/scripts/gotran2latex.py | 7 +- gotran/scripts/gotran2matlab.py | 3 +- gotran/scripts/gotran2md.py | 10 +- gotran/scripts/gotran2opencl.py | 16 +-- gotran/scripts/gotran2py.py | 3 +- gotran/scripts/gotranprobe.py | 3 +- gotran/scripts/gotranrun.py | 45 ++----- pyproject.toml | 2 +- sandbox/cellml/test_cellml.py | 12 +- sandbox/gpu/cossplotter.py | 5 +- sandbox/gpu/cosstester.py | 23 +--- sandbox/gpu/cuda_vs_goss.py | 20 +-- sandbox/gpu/cudaodesystemsolver.py | 28 +---- sandbox/gpu/test_gpu.py | 9 +- sandbox/gpu/testeverything.py | 21 +--- tests/codegeneration/test_codegeneration.py | 12 +- tests/input/test_cellml.py | 10 +- tests/input/test_compile_module.py | 6 +- tests/model/test_ode.py | 20 +-- tests/model/test_odeobjects.py | 25 +--- 46 files changed, 265 insertions(+), 750 deletions(-) diff --git a/bench/winslow_python.py b/bench/winslow_python.py index 000e8b5..6d17f84 100644 --- a/bench/winslow_python.py +++ b/bench/winslow_python.py @@ -53,8 +53,7 @@ "import time", "t0 = time.time()", f"for i in range({times}):", - " dy = rhs(init_states, 0.0" - + (")" if gen.oderepr.optimization.parameter_numerals else ", parameters)"), + " dy = rhs(init_states, 0.0" + (")" if gen.oderepr.optimization.parameter_numerals else ", parameters)"), f"""print ""\" keep_intermediates = {keep} use_cse = {use_cse} diff --git a/demo/gotran2c/demo.py b/demo/gotran2c/demo.py index f0aced4..74f976b 100644 --- a/demo/gotran2c/demo.py +++ b/demo/gotran2c/demo.py @@ -61,10 +61,10 @@ def init_parameters(): def solve(t_start, t_end, dt, num_steps=None, method="fe"): parameters = init_parameters() - if type(dt) is not float: + if not isinstance(t_start, float): dt = float(dt) if num_steps is not None: - assert type(num_steps) is int + assert isinstance(num_steps, int) t_end = dt * num_steps else: num_steps = round((t_end - t_start) / dt) diff --git a/gotran/algorithms/symbolicnewtonsolution.py b/gotran/algorithms/symbolicnewtonsolution.py index ea20305..dbe0341 100644 --- a/gotran/algorithms/symbolicnewtonsolution.py +++ b/gotran/algorithms/symbolicnewtonsolution.py @@ -257,9 +257,7 @@ def sum_ders(ders): # Generate F using theta rule F_expr = [ - theta * expr - + (1 - theta) * expr.subs(subs) - - (sum_ders(ders) - sum_ders(ders).subs(subs)) / ode.dt + theta * expr + (1 - theta) * expr.subs(subs) - (sum_ders(ders) - sum_ders(ders).subs(subs)) / ode.dt for ders, expr in ode.get_derivative_expr(True) ] diff --git a/gotran/codegeneration/algorithmcomponents.py b/gotran/codegeneration/algorithmcomponents.py index 1260d33..aaac615 100644 --- a/gotran/codegeneration/algorithmcomponents.py +++ b/gotran/codegeneration/algorithmcomponents.py @@ -173,8 +173,7 @@ def componentwise_derivative(ode, indices, params=None, result_name="dy"): for index in indices: if index < 0 or index >= ode.num_full_states: error( - "Expected the passed indices to be between 0 and the " - "number of states in the ode, got {0}.".format(index), + "Expected the passed indices to be between 0 and the " "number of states in the ode, got {0}.".format(index), ) if index in registered: error(f"Index {index} appeared twice.") @@ -496,11 +495,7 @@ def __init__( sys.stdout.flush() for i, expr in enumerate(state_exprs): - states_syms = sorted( - (state_dict[sym], sym) - for sym in ode_primitives(expr.expr, time_sym) - if sym in state_dict - ) + states_syms = sorted((state_dict[sym], sym) for sym in ode_primitives(expr.expr, time_sym) if sym in state_dict) self.add_comment( f"Expressions for the sparse jacobian of state {expr.state.name}", @@ -559,10 +554,7 @@ def __init__( """ check_arg(jacobian, JacobianComponent) - descr = ( - "Compute the diagonal jacobian of the right hand side of the " - "{0} ODE".format(jacobian.root) - ) + descr = "Compute the diagonal jacobian of the right hand side of the " "{0} ODE".format(jacobian.root) super(DiagonalJacobianComponent, self).__init__( "DiagonalJacobian", jacobian.root, @@ -628,10 +620,7 @@ def __init__( """ timer = Timer("Computing jacobian action component") # noqa: F841 check_arg(jacobian, JacobianComponent) - descr = ( - "Compute the jacobian action of the right hand side of the " - "{0} ODE".format(jacobian.root) - ) + descr = "Compute the jacobian action of the right hand side of the " "{0} ODE".format(jacobian.root) super(JacobianActionComponent, self).__init__( "JacobianAction", jacobian.root, @@ -696,9 +685,8 @@ def __init__( """ timer = Timer("Computing jacobian action component") # noqa: F841 check_arg(diagonal_jacobian, DiagonalJacobianComponent) - descr = ( - "Compute the diagonal jacobian action of the right hand side " - "of the {0} ODE".format(diagonal_jacobian.root) + descr = "Compute the diagonal jacobian action of the right hand side " "of the {0} ODE".format( + diagonal_jacobian.root ) super(DiagonalJacobianActionComponent, self).__init__( "DiagonalJacobianAction", @@ -851,9 +839,7 @@ def add_intermediate_if_changed(jac, jac_ij, i, j): # Store factorized jacobian self.factorized_jacobian = jac - self.num_nonzero = sum( - not jac[i, j].is_zero for i in range(n) for j in range(n) - ) + self.num_nonzero = sum(not jac[i, j].is_zero for i in range(n) for j in range(n)) # No need to call recreate body expressions self.body_expressions = self.ode_objects @@ -895,10 +881,7 @@ def __init__( timer = Timer("Computing forward backward substituion component") # noqa: F841 check_arg(factorized, FactorizedJacobianComponent) jacobian_name = list(factorized.shapes.keys())[0] - descr = ( - "Symbolically forward backward substitute linear system " - "of {0} ODE".format(factorized.root) - ) + descr = "Symbolically forward backward substitute linear system " "of {0} ODE".format(factorized.root) super(ForwardBackwardSubstitutionComponent, self).__init__( "ForwardBackwardSubst", factorized.root, @@ -969,11 +952,7 @@ def __init__( dx[i] = self.add_indexed_expression(result_name, i, dx[i] / jac[i, i]) # No need to call recreate body expressions - self.body_expressions = [ - obj - for obj in self.ode_objects - if isinstance(obj, (IndexedExpression, Comment)) - ] + self.body_expressions = [obj for obj in self.ode_objects if isinstance(obj, (IndexedExpression, Comment))] self.results = [result_name] self.used_states = set() @@ -1106,9 +1085,7 @@ def __init__(self, ode): # Collect all expanded state expressions org_state_expressions = ode.state_expressions - expanded_state_exprs = [ - ode.expanded_expressions[obj.name] for obj in org_state_expressions - ] + expanded_state_exprs = [ode.expanded_expressions[obj.name] for obj in org_state_expressions] # Call sympy common sub expression reduction cse_exprs, cse_state_exprs = cse( diff --git a/gotran/codegeneration/avoidsympycontractions.py b/gotran/codegeneration/avoidsympycontractions.py index a472a68..3a09c26 100644 --- a/gotran/codegeneration/avoidsympycontractions.py +++ b/gotran/codegeneration/avoidsympycontractions.py @@ -83,10 +83,7 @@ def _function_new(cls, *args, **options): # it work with NumPy's functions like vectorize(). The ideal # solution would be just to attach metadata to the exception # and change NumPy to take advantage of this. - temp = ( - "%(name)s takes exactly %(args)s " - "argument%(plural)s (%(given)s given)" - ) + temp = "%(name)s takes exactly %(args)s " "argument%(plural)s (%(given)s given)" raise TypeError( temp % { diff --git a/gotran/codegeneration/codecomponent.py b/gotran/codegeneration/codecomponent.py index 5e3d8a6..e8ad6cc 100644 --- a/gotran/codegeneration/codecomponent.py +++ b/gotran/codegeneration/codecomponent.py @@ -197,15 +197,17 @@ def add_indexed_expression( # Check that provided indices fit with the registered shape if len(self.shapes[basename]) > len(indices): error( - "Shape mismatch between indices {0} and registered " - "shape for {1}{2}".format(indices, basename, self.shapes[basename]), + "Shape mismatch between indices {0} and registered " "shape for {1}{2}".format( + indices, basename, self.shapes[basename] + ), ) for dim, (index, shape_ind) in enumerate(zip(indices, self.shapes[basename])): if index >= shape_ind: error( - "Indices must be smaller or equal to the shape. Mismatch " - "in dim {0}: {1}>={2}".format(dim + 1, index, shape_ind), + "Indices must be smaller or equal to the shape. Mismatch " "in dim {0}: {1}>={2}".format( + dim + 1, index, shape_ind + ), ) # Create the indexed expression @@ -270,15 +272,17 @@ def add_indexed_object(self, basename, indices, add_offset=False): # Check that provided indices fit with the registered shape if len(self.shapes[basename]) > len(indices): error( - "Shape mismatch between indices {0} and registered " - "shape for {1}{2}".format(indices, basename, self.shapes[basename]), + "Shape mismatch between indices {0} and registered " "shape for {1}{2}".format( + indices, basename, self.shapes[basename] + ), ) for dim, (index, shape_ind) in enumerate(zip(indices, self.shapes[basename])): if index >= shape_ind: error( - "Indices must be smaller or equal to the shape. Mismatch " - "in dim {0}: {1}>={2}".format(dim + 1, index, shape_ind), + "Indices must be smaller or equal to the shape. Mismatch " "in dim {0}: {1}>={2}".format( + dim + 1, index, shape_ind + ), ) # Create IndexedObject @@ -301,11 +305,7 @@ def indexed_objects(self, *basenames): """ if not basenames: basenames = list(self.shapes.keys()) - return [ - obj - for obj in self.ode_objects - if isinstance(obj, IndexedObject) and obj.basename in basenames - ] + return [obj for obj in self.ode_objects if isinstance(obj, IndexedObject) and obj.basename in basenames] def _init_param_state_replace_dict(self): """ @@ -383,23 +383,19 @@ def _init_param_state_replace_dict(self): # If not having named parameters if param_repr == "numerals": - param_state_replace_dict.update( - (param.sym, param.init) for param in self.root.parameters - ) + param_state_replace_dict.update((param.sym, param.init) for param in self.root.parameters) elif param_repr == "array": self.shapes[param_name] = (self.root.num_parameters,) if field_parameters: self.shapes["field_" + param_name] = (len(field_parameters),) param_state_replace_dict.update( - (param.sym, indexed.sym) - for param, indexed in list(param_state_map["parameters"].items()) + (param.sym, indexed.sym) for param, indexed in list(param_state_map["parameters"].items()) ) if state_repr == "array": self.shapes[state_name] = (self.root.num_full_states,) param_state_replace_dict.update( - (state.sym, indexed.sym) - for state, indexed in list(param_state_map["states"].items()) + (state.sym, indexed.sym) for state, indexed in list(param_state_map["states"].items()) ) param_state_replace_dict[self.root._time.sym] = sp.Symbol(time_name) @@ -433,15 +429,11 @@ def _expanded_result_expressions(self, **results): # A map between result expression and result name result_names = dict( - (result_expr, result_name) - for result_name, result_exprs in list(results.items()) - for result_expr in result_exprs + (result_expr, result_name) for result_name, result_exprs in list(results.items()) for result_expr in result_exprs ) # The expanded result expressions - expanded_result_exprs = [ - self.root.expanded_expression(obj) for obj in orig_result_expressions - ] + expanded_result_exprs = [self.root.expanded_expression(obj) for obj in orig_result_expressions] # Set shape for result expressions for result_name, result_expressions in list(results.items()): @@ -476,8 +468,7 @@ def _body_from_cse(self, **results): if might_take_time: info( - "Computing common sub expressions for {0}. Might take " - "some time...".format(self.name), + "Computing common sub expressions for {0}. Might take " "some time...".format(self.name), ) sys.stdout.flush() @@ -521,11 +512,7 @@ def _body_from_cse(self, **results): result_expr_map[result_expr].append( ( result_names[orig_result_expr], - ( - orig_result_expr.indices - if isinstance(orig_result_expr, IndexedExpression) - else ind - ), + (orig_result_expr.indices if isinstance(orig_result_expr, IndexedExpression) else ind), ), ) @@ -716,9 +703,7 @@ def _recreate_body(self, body_expressions, **results): # A map between result expression and result name result_names = dict( - (result_expr, result_name) - for result_name, result_exprs in list(results.items()) - for result_expr in result_exprs + (result_expr, result_name) for result_name, result_exprs in list(results.items()) for result_expr in result_exprs ) timer = Timer(f"Recreate body expressions for {self.name}") # noqa: F841 @@ -760,12 +745,8 @@ def _recreate_body(self, body_expressions, **results): replaced_expr_map = OrderedDict() new_body_expressions = [] - present_ode_objects = dict( - (state.name, state) for state in self.root.full_states - ) - present_ode_objects.update( - (param.name, param) for param in self.root.parameters - ) + present_ode_objects = dict((state.name, state) for state in self.root.full_states) + present_ode_objects.update((param.name, param) for param in self.root.parameters) old_present_ode_objects = present_ode_objects.copy() def store_expressions(expr, new_expr): @@ -858,11 +839,7 @@ def store_expressions(expr, new_expr): # index information so that the index previously available # for this expressions gets available at the last expressions # the present expression is used in. - if ( - isinstance(dep_expr, IndexedExpression) - and dep_expr.basename == body_name - and "reused" in body_repr - ): + if isinstance(dep_expr, IndexedExpression) and dep_expr.basename == body_name and "reused" in body_repr: ind = dep_expr.indices[0] # Remove available index information diff --git a/gotran/codegeneration/codegenerators.py b/gotran/codegeneration/codegenerators.py index 7774088..dd44ff0 100644 --- a/gotran/codegeneration/codegenerators.py +++ b/gotran/codegeneration/codegenerators.py @@ -419,17 +419,13 @@ def _is_number(num_str): if not is_comment and ( '"' in line_stump[-1] and not ( - '\\"' in line_stump[-1] - or '"""' in line_stump[-1] - or re.search(_re_str, line_stump[-1]) + '\\"' in line_stump[-1] or '"""' in line_stump[-1] or re.search(_re_str, line_stump[-1]) ) ): inside_str = not inside_str # Check line length - line_length += ( - len(line_stump[-1]) + 1 + (is_comment and not first_line) * (len(cls.comment) + 1) - ) + line_length += len(line_stump[-1]) + 1 + (is_comment and not first_line) * (len(cls.comment) + 1) # If we are inside a str and at the end of line add if inside_str and not is_comment: @@ -444,9 +440,7 @@ def _is_number(num_str): # If it is the last line stump add line ending otherwise # line continuation sign - ret_lines[-1] = ret_lines[-1] + (not is_comment) * ( - cls.line_cont if splitted_line else line_ending - ) + ret_lines[-1] = ret_lines[-1] + (not is_comment) * (cls.line_cont if splitted_line else line_ending) first_line = False else: @@ -589,9 +583,7 @@ def _init_arguments(self, comp): # If all states are used if len(used_states) == len(comp.root.full_states): body_lines.append( - ", ".join(state.name for i, state in enumerate(comp.root.full_states)) - + " = " - + states_name, + ", ".join(state.name for i, state in enumerate(comp.root.full_states)) + " = " + states_name, ) # If only a limited number of states are used @@ -605,11 +597,7 @@ def _init_arguments(self, comp): ) # Add parameters code if not numerals - if ( - "p" in default_arguments - and params.parameters.representation in ["named", "array"] - and used_parameters - ): + if "p" in default_arguments and params.parameters.representation in ["named", "array"] and used_parameters: parameters_name = params.parameters.array_name body_lines.append("") body_lines.append("# Assign parameters") @@ -620,9 +608,7 @@ def _init_arguments(self, comp): # If all parameters are used if len(used_parameters) == len(comp.root.parameters): body_lines.append( - ", ".join(param.name for i, param in enumerate(used_parameters)) - + " = " - + parameters_name, + ", ".join(param.name for i, param in enumerate(used_parameters)) + " = " + parameters_name, ) # If only a limited number of states are used @@ -848,8 +834,7 @@ def init_states_code(self, ode, indent=0, perform_range_check=False): "ind, range = state_ind[state_name]", "if value not in range:", [ - "raise ValueError(\"While setting '{0}' {1}\".format(" - "state_name, range.format_not_in(value)))", + "raise ValueError(\"While setting '{0}' {1}\".format(" "state_name, range.format_not_in(value)))", ], "", "# Assign value", @@ -933,9 +918,7 @@ def init_parameters_code(self, ode, indent=0, perform_range_check=False): else: body_lines.append( "param_ind = dict([{0}])".format( - ", ".join( - '("{0}", {1})'.format(param.param.name, i) for i, param in enumerate(parameters) - ), + ", ".join('("{0}", {1})'.format(param.param.name, i) for i, param in enumerate(parameters)), ), ) body_lines.append("") @@ -951,8 +934,7 @@ def init_parameters_code(self, ode, indent=0, perform_range_check=False): "ind, range = param_ind[param_name]", "if value not in range:", [ - "raise ValueError(\"While setting '{0}' {1}\".format(" - "param_name, range.format_not_in(value)))", + "raise ValueError(\"While setting '{0}' {1}\".format(" "param_name, range.format_not_in(value)))", ], "", "# Assign value", @@ -1411,11 +1393,7 @@ def add_obj(obj, i, array_name, add_offset=False): add_obj(state, i, states_name, state_offset) # Add parameters code if not numerals - if ( - "p" in default_arguments - and params.parameters.representation in ["named", "array"] - and used_parameters - ): + if "p" in default_arguments and params.parameters.representation in ["named", "array"] and used_parameters: # Generate parameters assign code if params.parameters.representation == "named": parameters_name = params.parameters.array_name @@ -1444,11 +1422,7 @@ def add_obj(obj, i, array_name, add_offset=False): add_obj(param, i, field_parameters_name, field_parameter_offset) # If using an array for the body variables and b is not passed as argument - if ( - params.body.representation != "named" - and not params.body.in_signature - and params.body.array_name in comp.shapes - ): + if params.body.representation != "named" and not params.body.in_signature and params.body.array_name in comp.shapes: body_name = params.body.array_name body_lines.append("") body_lines.append(f"// Body array {body_name}") @@ -2041,11 +2015,7 @@ def add_obj(obj, i, array_name, add_offset=False): add_state_obj(state, self._state_enum_val(state), states_name) # Add parameters code if not numerals - if ( - "p" in default_arguments - and params.parameters.representation in ["named", "array"] - and used_parameters - ): + if "p" in default_arguments and params.parameters.representation in ["named", "array"] and used_parameters: # Generate parameters assign code if params.parameters.representation == "named": parameters_name = params.parameters.array_name @@ -2071,11 +2041,7 @@ def add_obj(obj, i, array_name, add_offset=False): add_obj(param, i, field_parameters_name, field_parameter_offset) # If using an array for the body variables and b is not passed as argument - if ( - params.body.representation != "named" - and not params.body.in_signature - and params.body.array_name in comp.shapes - ): + if params.body.representation != "named" and not params.body.in_signature and params.body.array_name in comp.shapes: body_name = params.body.array_name body_lines.append("") body_lines.append(f"// Body array {body_name}") @@ -2688,11 +2654,7 @@ def add_obj(obj, i, array_name, add_offset=False): add_state_obj(state, self._state_enum_val(state), states_name) # Add parameters code if not numerals - if ( - "p" in default_arguments - and params.parameters.representation in ["named", "array"] - and used_parameters - ): + if "p" in default_arguments and params.parameters.representation in ["named", "array"] and used_parameters: # Generate parameters assign code if params.parameters.representation == "named": parameters_name = params.parameters.array_name @@ -2718,11 +2680,7 @@ def add_obj(obj, i, array_name, add_offset=False): add_obj(param, i, field_parameters_name, field_parameter_offset) # If using an array for the body variables and b is not passed as argument - if ( - params.body.representation != "named" - and not params.body.in_signature - and params.body.array_name in comp.shapes - ): + if params.body.representation != "named" and not params.body.in_signature and params.body.array_name in comp.shapes: body_name = params.body.array_name body_lines.append("") body_lines.append(f"// Body array {body_name}") @@ -3422,11 +3380,7 @@ def _init_arguments(self, comp): ) # Add parameters code if not numerals - if ( - "p" in default_arguments - and params.parameters.representation in ["named", "array"] - and used_parameters - ): + if "p" in default_arguments and params.parameters.representation in ["named", "array"] and used_parameters: parameters_name = params.parameters.array_name body_lines.append("") body_lines.append("% Assign parameters") @@ -3686,9 +3640,7 @@ def function_code(self, comp, indent=0, include_signature=True): body_lines.append( "return {0}".format( ", ".join( - ("{0}[0]" if comp.shapes[result_name][0] == 1 else "dolfin.as_vector({0})").format( - result_name - ) + ("{0}[0]" if comp.shapes[result_name][0] == 1 else "dolfin.as_vector({0})").format(result_name) for result_name in comp.results ), ), @@ -3754,8 +3706,7 @@ def init_states_code(self, ode, indent=0): "ind, range_check, not_in_format = state_ind[state_name]", "if not range_check(value):", [ - "raise ValueError(\"While setting '{0}' {1}\".format(" - "state_name, not_in_format % str(value)))", + "raise ValueError(\"While setting '{0}' {1}\".format(" "state_name, not_in_format % str(value)))", ], "", "# Assign value", @@ -3822,8 +3773,7 @@ def init_parameters_code(self, ode, indent=0): "ind, range_check, not_in_format = parameter_ind[param_name]", "if not range_check(value):", [ - "raise ValueError(\"While setting '{0}' {1}\".format(" - "param_name, not_in_format % str(value)))", + "raise ValueError(\"While setting '{0}' {1}\".format(" "param_name, not_in_format % str(value)))", ], "", "# Assign value", @@ -3974,9 +3924,7 @@ def _init_arguments(self, comp): # If all states are used if len(used_states) == len(comp.root.full_states): body_lines.append( - ", ".join(state.name for i, state in enumerate(comp.root.full_states)) - + " = " - + states_name, + ", ".join(state.name for i, state in enumerate(comp.root.full_states)) + " = " + states_name, ) # If only a limited number of states are used @@ -3990,11 +3938,7 @@ def _init_arguments(self, comp): ) # Add parameters code if not numerals - if ( - "p" in default_arguments - and params.parameters.representation in ["named", "array"] - and used_parameters - ): + if "p" in default_arguments and params.parameters.representation in ["named", "array"] and used_parameters: parameters_name = params.parameters.array_name body_lines.append("") body_lines.append("# Assign parameters") @@ -4005,9 +3949,7 @@ def _init_arguments(self, comp): # If all parameters are used if len(used_parameters) == len(comp.root.parameters): body_lines.append( - ", ".join(param.name for i, param in enumerate(used_parameters)) - + " = " - + parameters_name, + ", ".join(param.name for i, param in enumerate(used_parameters)) + " = " + parameters_name, ) # If only a limited number of states are used @@ -4144,9 +4086,7 @@ def init_states_code(self, ode, indent=0): # for i, state in enumerate(states, start=1)))) body_lines.append( "state_ind = Dict({0})".format( - ", ".join( - '"{0}" => {1}'.format(state.param.name, i) for i, state in enumerate(states, start=1) - ), + ", ".join('"{0}" => {1}'.format(state.param.name, i) for i, state in enumerate(states, start=1)), ), ) body_lines.append("") @@ -4200,9 +4140,7 @@ def init_parameters_code(self, ode, indent=0): body_lines.append( "param_ind = Dict({0})".format( - ", ".join( - '"{0}" => {1}'.format(state.param.name, i) for i, state in enumerate(parameters, start=1) - ), + ", ".join('"{0}" => {1}'.format(state.param.name, i) for i, state in enumerate(parameters, start=1)), ), ) body_lines.append("") @@ -4239,9 +4177,7 @@ def state_name_to_index_code(self, ode, indent=0): body_lines.append( "state_inds = Dict({0})".format( - ", ".join( - '"{0}" => {1}'.format(state.param.name, i) for i, state in enumerate(states, start=1) - ), + ", ".join('"{0}" => {1}'.format(state.param.name, i) for i, state in enumerate(states, start=1)), ), ) @@ -4275,9 +4211,7 @@ def param_name_to_index_code(self, ode, indent=0): body_lines = [] body_lines.append( "param_inds = Dict({0})".format( - ", ".join( - '"{0}" => {1}'.format(param.param.name, i) for i, param in enumerate(parameters, start=1) - ), + ", ".join('"{0}" => {1}'.format(param.param.name, i) for i, param in enumerate(parameters, start=1)), ), ) diff --git a/gotran/codegeneration/compilemodule.py b/gotran/codegeneration/compilemodule.py index 4f9af0a..c485b9b 100644 --- a/gotran/codegeneration/compilemodule.py +++ b/gotran/codegeneration/compilemodule.py @@ -304,12 +304,7 @@ def parse_monitor_declaration(ode, args, args_doc, params, monitored): def signature(ode, monitored, params, languange) -> str: return hashlib.sha1( str( - ode.signature() - + str(monitored) - + repr(params) - + languange - + __version__ - + cppyy_version(), + ode.signature() + str(monitored) + repr(params) + languange + __version__ + cppyy_version(), ).encode("utf-8"), ).hexdigest() diff --git a/gotran/codegeneration/latexcodegenerator.py b/gotran/codegeneration/latexcodegenerator.py index 294c23e..1bb2ad1 100644 --- a/gotran/codegeneration/latexcodegenerator.py +++ b/gotran/codegeneration/latexcodegenerator.py @@ -178,8 +178,7 @@ def _default_latex_params(): # "blank") params["math_font_size"] = Param( 0.0, - description="Set font size for mathematical expressions in " - "LaTeX document. Uses global font size if left blank", + description="Set font size for mathematical expressions in " "LaTeX document. Uses global font size if left blank", ) # Toggle bold equation labels @@ -308,9 +307,7 @@ def generate(self, params=None): global_opts = self.format_global_options(_global_opts) if not params.preamble: - latex_output = ( - self.generate_parameter_table() + self.generate_state_table() + self.generate_components() - ) + latex_output = self.generate_parameter_table() + self.generate_state_table() + self.generate_components() else: document_opts = self.format_options( override=["font_size", "landscape", "page_numbers"], @@ -320,9 +317,7 @@ def generate(self, params=None): PKGS=self.format_packages(self.packages), PREOPTS=global_opts, OPTS=document_opts["begin"], - BODY=self.generate_parameter_table() - + self.generate_state_table() - + self.generate_components(), + BODY=self.generate_parameter_table() + self.generate_state_table() + self.generate_components(), ENDOPTS=document_opts["end"], ) @@ -370,9 +365,7 @@ def generate_components(self, params=None): params = params if params else self.params components_str = "" comp_template = ( - "{LABEL}\n\\label{{comp:{LABELID}}}\n" - "\\begin{{dgroup{SUBNUM}}}\n" - "{BODY}\\end{{dgroup{SUBNUM}}}\n" + "{LABEL}\n\\label{{comp:{LABELID}}}\n" "\\begin{{dgroup{SUBNUM}}}\n" "{BODY}\\end{{dgroup{SUBNUM}}}\n" ) eqn_template = " \\begin{{dmath}}\n \\label{{eq:{0}}}\n" " {1} = {2}\\\\\n \\end{{dmath}}\n" @@ -381,9 +374,7 @@ def generate_components(self, params=None): for comp in self.ode.components: if comp.rates: body = [ - obj - for obj in comp.ode_objects - if isinstance(obj, Expression) and not isinstance(obj, StateDerivative) + obj for obj in comp.ode_objects if isinstance(obj, Expression) and not isinstance(obj, StateDerivative) ] else: body = [obj for obj in comp.ode_objects if isinstance(obj, Expression)] @@ -520,9 +511,7 @@ def format_options(self, exclude=None, override=None, params=None): begin_str = end_str = "" - if opts.page_columns > 1 and ( - ("page_columns" not in exclude and not override) or "page_columns" in override - ): + if opts.page_columns > 1 and (("page_columns" not in exclude and not override) or "page_columns" in override): begin_str = f"\\begin{{multicols}}{{{opts.page_columns}}}\n" + begin_str end_str += "\\end{multicols}\n" diff --git a/gotran/codegeneration/oderepresentation.py b/gotran/codegeneration/oderepresentation.py index 839dec4..d45cb51 100644 --- a/gotran/codegeneration/oderepresentation.py +++ b/gotran/codegeneration/oderepresentation.py @@ -119,8 +119,7 @@ def _default_params(exclude=None): # only when keep_intermediates is false params["use_cse"] = Param( False, - description="Use sympy common sub expression " - "simplifications, only when keep_intermediates is false", + description="Use sympy common sub expression " "simplifications, only when keep_intermediates is false", ) if "generate_jacobian" not in exclude: @@ -366,9 +365,7 @@ def _compute_jacobian_action(self): # Create a state vector ode_states = sp.Matrix(ode.num_states, 1, lambda i, j: self.ode.states[i]) self._jacobian_action_vec = self._jacobian_mat * ode_states - self._jacobian_action_expr = [ - jac_ac_expr.subs(self._jacobian_expr) for jac_ac_expr in self._jacobian_action_vec - ] + self._jacobian_action_expr = [jac_ac_expr.subs(self._jacobian_expr) for jac_ac_expr in self._jacobian_action_vec] def _compute_jacobian_action_cse(self): self._compute_jacobian_action() @@ -377,8 +374,9 @@ def _compute_jacobian_action_cse(self): return info( - "Calculating jacobian action common sub expressions for {0} entries. " - "May take some time...".format(len(self._jacobian_action_expr)), + "Calculating jacobian action common sub expressions for {0} entries. " "May take some time...".format( + len(self._jacobian_action_expr) + ), ) sys.stdout.flush() @@ -827,8 +825,7 @@ def symbol_subs(self): subs.extend((param.sym, param.init) for param in self.ode.parameters) elif not self.optimization.use_parameter_names: subs.extend( - (param.sym, sp.Symbol("parameters" + self.index(ind))) - for ind, param in enumerate(self.ode.parameters) + (param.sym, sp.Symbol("parameters" + self.index(ind))) for ind, param in enumerate(self.ode.parameters) ) elif self._parameter_prefix: subs.extend( @@ -841,10 +838,7 @@ def symbol_subs(self): # Deal with state subs if not self.optimization.use_state_names: - subs.extend( - (state.sym, sp.Symbol("states" + self.index(ind))) - for ind, state in enumerate(self.ode.states) - ) + subs.extend((state.sym, sp.Symbol("states" + self.index(ind))) for ind, state in enumerate(self.ode.states)) elif self._state_prefix: subs.extend( @@ -871,9 +865,7 @@ def iter_derivative_expr(self): # No intermediates and no CSE if not self.optimization.use_cse: - return ( - (derivatives, self.subs(expr)) for derivatives, expr in self.ode.get_derivative_expr(True) - ) + return ((derivatives, self.subs(expr)) for derivatives, expr in self.ode.get_derivative_expr(True)) # Use CSE else: diff --git a/gotran/codegeneration/solvercomponents.py b/gotran/codegeneration/solvercomponents.py index b92283c..4776929 100644 --- a/gotran/codegeneration/solvercomponents.py +++ b/gotran/codegeneration/solvercomponents.py @@ -221,9 +221,7 @@ def __init__(self, ode, function_name="forward_explicit_euler", params=None): error("Cannot generate an explicit Euler forward step for a DAE.") # Call base class using empty result_expressions - descr = ( - f"Compute a forward step using the explicit Euler scheme to the {ode} ODE" - ) + descr = f"Compute a forward step using the explicit Euler scheme to the {ode} ODE" super(ExplicitEuler, self).__init__( "ExplicitEuler", ode, @@ -562,10 +560,7 @@ def __init__( check_arg(ode, ODE) # Call base class using empty result_expressions - descr = ( - "Compute a forward step using the generalised Rush-Larsen (GRL1) scheme to the " - "{0} ODE".format(ode) - ) + descr = "Compute a forward step using the generalised Rush-Larsen (GRL1) scheme to the " "{0} ODE".format(ode) super(GeneralizedRushLarsen, self).__init__( "GeneralizedRushLarsen", ode, @@ -687,7 +682,7 @@ def __init__( state_names = [s.name for s in self.root.full_states] if stiff_state_variables is None: stiff_state_variables = [] - elif type(stiff_state_variables) is str: + elif isinstance(stiff_state_variables, str): stiff_state_variables = stiff_state_variables.split(",") for s in stiff_state_variables: @@ -799,10 +794,7 @@ def __init__( error("Cannot generate an explicit Euler forward step for a DAE.") # Call base class using empty result_expressions - descr = ( - "Compute a forward step using the simplified implicit Euler" - "scheme to the {0} ODE".format(ode) - ) + descr = "Compute a forward step using the simplified implicit Euler" "scheme to the {0} ODE".format(ode) super(SimplifiedImplicitEuler, self).__init__( "SimplifiedImplicitEuler", ode, diff --git a/gotran/common/options.py b/gotran/common/options.py index 7834eb8..a0b7f1c 100644 --- a/gotran/common/options.py +++ b/gotran/common/options.py @@ -41,8 +41,7 @@ default_arguments=OptionParam( "stp", ["tsp", "stp", "spt", "ts", "st"], - description="Default input argument order: " - "s=states, p=parameters, t=time", + description="Default input argument order: " "s=states, p=parameters, t=time", ), # Parameter for the time parameter name time=ParameterDict(name=Param("t", description="Name of time argument")), @@ -93,18 +92,15 @@ ), field_array_name=Param( "field_parameters", - description="The name of the array " - "representing the field parameters.", + description="The name of the array " "representing the field parameters.", ), add_offset=Param( False, - description="If true an offset will be " - "added to the index of each parameter", + description="If true an offset will be " "added to the index of each parameter", ), add_field_offset=Param( False, - description="If true an offset will be " - "added to the index of each field parameter", + description="If true an offset will be " "added to the index of each field parameter", ), ), # Parameters for code generation of states @@ -128,27 +124,22 @@ ), add_offset=Param( False, - description="If true an offset will be " - "added to the index of each state", + description="If true an offset will be " "added to the index of each state", ), ), # Parameters for code generation of body expressions body=ParameterDict( use_cse=Param( False, - description="If true will the body be " - "optimized using SymPy common sub expression " - "extraction.", + description="If true will the body be " "optimized using SymPy common sub expression " "extraction.", ), use_enum=Param( False, - description="If true use enumeration" - "data types instead of indexing.", + description="If true use enumeration" "data types instead of indexing.", ), in_signature=Param( False, - description="If true the body argument " - "will be included in the signature.", + description="If true the body argument " "will be included in the signature.", ), representation=OptionParam( "named", @@ -176,8 +167,7 @@ rhs=ParameterDict( generate=Param( True, - description="Generate code for the " - "evaluation of the right hand side evaluation.", + description="Generate code for the " "evaluation of the right hand side evaluation.", ), function_name=Param( "rhs", @@ -191,8 +181,7 @@ monitored=ParameterDict( generate=Param( True, - description="Generate code for the " - "evaluation of monitored intermediates.", + description="Generate code for the " "evaluation of monitored intermediates.", ), function_name=Param( "monitor", @@ -206,9 +195,7 @@ jacobian=ParameterDict( generate=Param( False, - description="Generate code for the " - "evaluation of the jacobian of the right hand " - "side.", + description="Generate code for the " "evaluation of the jacobian of the right hand " "side.", ), function_name=Param( "compute_jacobian", @@ -222,8 +209,7 @@ lu_factorization=ParameterDict( generate=Param( False, - description="Generate code for " - "symbolicly factorize the jacobian.", + description="Generate code for " "symbolicly factorize the jacobian.", ), function_name=Param( "lu_factorize", @@ -233,9 +219,7 @@ forward_backward_subst=ParameterDict( generate=Param( False, - description="Generate code for the " - "symbolic forward backward substitution of the " - "jacobian.", + description="Generate code for the " "symbolic forward backward substitution of the " "jacobian.", ), function_name=Param( "forward_backward_subst", @@ -253,8 +237,7 @@ componentwise_rhs_evaluation=ParameterDict( generate=Param( False, - description="If true, generate code for " - "computing componentwise evaluation of the rhs.", + description="If true, generate code for " "computing componentwise evaluation of the rhs.", ), function_name=Param( "componentwise_rhs", @@ -264,9 +247,7 @@ linearized_rhs_evaluation=ParameterDict( generate=Param( False, - description="If true, generate code for " - "computing linearized evaluation of linear rhs " - "terms.", + description="If true, generate code for " "computing linearized evaluation of linear rhs " "terms.", ), function_name=Param( "linearized_rhs", @@ -274,14 +255,11 @@ ), include_rhs=Param( False, - description="If True the rhs will be " - "included as a result argument.", + description="If True the rhs will be " "included as a result argument.", ), only_linear=Param( True, - description="If True only linearized " - "expressions for the linear derivatives will " - "be generated.", + description="If True only linearized " "expressions for the linear derivatives will " "be generated.", ), result_names=Param( ["linearized", "rhs"], @@ -305,8 +283,7 @@ explicit_euler=ParameterDict( generate=Param( False, - description="If true, generate code for " - "solving an ODE using explicit Euler method.", + description="If true, generate code for " "solving an ODE using explicit Euler method.", ), function_name=Param( "forward_explicit_euler", @@ -316,8 +293,7 @@ rush_larsen=ParameterDict( generate=Param( False, - description="If true, generate code for " - "solving an ODE using Rush Larsen method.", + description="If true, generate code for " "solving an ODE using Rush Larsen method.", ), function_name=Param( "forward_rush_larsen", @@ -327,16 +303,13 @@ 1e-8, gt=0, lt=1.0, - description="Value to " - "safeguard the evaluation of the rush larsen " - "step.", + description="Value to " "safeguard the evaluation of the rush larsen " "step.", ), ), generalized_rush_larsen=ParameterDict( generate=Param( False, - description="If true, generate code for " - "solving an ODE using generalized Rush Larsen method.", + description="If true, generate code for " "solving an ODE using generalized Rush Larsen method.", ), function_name=Param( "forward_generalized_rush_larsen", @@ -346,9 +319,7 @@ 1e-8, gt=0, lt=1.0, - description="Value to " - "safeguard the evaluation of the rush larsen " - "step.", + description="Value to " "safeguard the evaluation of the rush larsen " "step.", ), ), hybrid_generalized_rush_larsen=ParameterDict( @@ -365,9 +336,7 @@ 1e-8, gt=0, lt=1.0, - description="Value to " - "safeguard the evaluation of the rush larsen " - "step.", + description="Value to " "safeguard the evaluation of the rush larsen " "step.", ), stiff_states=Param( "", @@ -380,8 +349,7 @@ simplified_implicit_euler=ParameterDict( generate=Param( False, - description="If true, generate code for " - "solving an ODE using Rush Larsen method.", + description="If true, generate code for " "solving an ODE using Rush Larsen method.", ), function_name=Param( "forward_simplified_implicit_euler", @@ -407,8 +375,7 @@ grouping=OptionParam( "encapsulation", ["encapsulation", "containment"], - description="Determines what type of grouping " - "should be used when the cellml model is parsed.", + description="Determines what type of grouping " "should be used when the cellml model is parsed.", ), use_sympy_integers=Param( False, diff --git a/gotran/input/cellml.py b/gotran/input/cellml.py index 266f2b5..8c3e92e 100644 --- a/gotran/input/cellml.py +++ b/gotran/input/cellml.py @@ -159,9 +159,7 @@ def __init__(self, name, variables, equations, state_variables=None): self.variable_info[state] = _info self.variable_info["type"] = "state_variable" - self.parameters = OrderedDict( - (name, _info) for name, _info in list(variables.items()) if _info["init"] is not None - ) + self.parameters = OrderedDict((name, _info) for name, _info in list(variables.items()) if _info["init"] is not None) for param, _info in list(self.parameters.items()): self.variable_info[param] = _info @@ -249,9 +247,7 @@ def check_dependencies(self, component): assert isinstance(component, Component) if any(equation.name in self.used_variables for equation in component.equations): - dep_equations = [ - equation for equation in component.equations if equation.name in self.used_variables - ] + dep_equations = [equation for equation in component.equations if equation.name in self.used_variables] # Register mutual dependencies self.dependencies[component] = dep_equations @@ -315,8 +311,7 @@ def change_state_name(self, oldname, newname=None): # Update parameters self.state_variables = OrderedDict( - (newname if name == oldname else name, value) - for name, value in list(self.state_variables.items()) + (newname if name == oldname else name, value) for name, value in list(self.state_variables.items()) ) oldder = self.derivatives[oldname] @@ -493,10 +488,7 @@ def parse_documentation(self): # Clean up content content = ( - ("\n".join(cont.strip() for cont in content)) - .replace(" ", " ") - .replace(" .", ".") - .replace(" ,", ",") + ("\n".join(cont.strip() for cont in content)).replace(" ", " ").replace(" .", ".").replace(" ,", ",") ) break else: @@ -635,10 +627,7 @@ def check_and_register_component_variables( ), ) for change_comp in [comp, state_comp]: - if ( - change_comp.state_variables[name]["private"] - and change_comp.variable_info[der_name]["private"] - ): + if change_comp.state_variables[name]["private"] and change_comp.variable_info[der_name]["private"]: new_name = change_comp.change_state_name(name) if change_comp == state_comp: collected_states[new_name] = change_comp @@ -655,8 +644,7 @@ def check_and_register_component_variables( elif name in collected_parameters: param_comp = collected_parameters[name] begin_log( - "State name: '{0}' from component '{1}' is used as " - "parameter in component '{2}'.".format( + "State name: '{0}' from component '{1}' is used as " "parameter in component '{2}'.".format( name, comp.name, param_comp.name, @@ -686,8 +674,7 @@ def check_and_register_component_variables( elif name in collected_equations: eq_comp = collected_equations[name] begin_log( - "State name '{0}' from component '{1}' is used as " - "parameter in component '{2}'.".format( + "State name '{0}' from component '{1}' is used as " "parameter in component '{2}'.".format( name, comp.name, eq_comp.name, @@ -730,10 +717,7 @@ def check_and_register_component_variables( # If parameter is private we change that if comp.parameters[name]["private"]: name = comp.change_parameter_name(name) - elif ( - state_comp.state_variables[name]["private"] - and state_comp.variable_info[der_name]["private"] - ): + elif state_comp.state_variables[name]["private"] and state_comp.variable_info[der_name]["private"]: new_name = state_comp.change_state_name(name) collected_states.pop(name) collected_states[new_name] = state_comp @@ -750,8 +734,7 @@ def check_and_register_component_variables( elif name in collected_parameters: param_comp = collected_parameters[name] begin_log( - "Parameter name '{0}' from component '{1}' is used as " - "parameter in component '{2}'.".format( + "Parameter name '{0}' from component '{1}' is used as " "parameter in component '{2}'.".format( name, comp.name, param_comp.name, @@ -781,8 +764,7 @@ def check_and_register_component_variables( elif name in collected_equations: eq_comp = collected_equations[name] begin_log( - "Parameter name '{0}' from component '{1}' " - "is used as parameter in component '{2}'.".format( + "Parameter name '{0}' from component '{1}' " "is used as parameter in component '{2}'.".format( name, comp.name, eq_comp.name, @@ -837,10 +819,7 @@ def check_and_register_component_variables( # If equation is private we change that # if comp.variable_info[name]["private"]: # name = comp.change_equation_name(name) - if ( - state_comp.state_variables[name]["private"] - and state_comp.variable_info[der_name]["private"] - ): + if state_comp.state_variables[name]["private"] and state_comp.variable_info[der_name]["private"]: new_name = state_comp.change_state_name(name) collected_states.pop(name) collected_states[new_name] = state_comp @@ -858,8 +837,7 @@ def check_and_register_component_variables( elif name in collected_parameters: param_comp = collected_parameters[name] begin_log( - "Equation name '{0}' from component '{1}' is used as " - "parameter in component '{2}'.".format( + "Equation name '{0}' from component '{1}' is used as " "parameter in component '{2}'.".format( name, comp.name, param_comp.name, @@ -890,8 +868,7 @@ def check_and_register_component_variables( elif name in collected_equations: eq_comp = collected_equations[name] info( - "Equation name '{0}' from component '{1}' is used as " - "equation name in component '{2}'.".format( + "Equation name '{0}' from component '{1}' is used as " "equation name in component '{2}'.".format( name, comp.name, eq_comp.name, @@ -1129,8 +1106,7 @@ def simple_sort(components): import networkx as nx except ImportError: warning( - "networkx could not be imported. Circular " - "dependencies between components will not be sorted out.", + "networkx could not be imported. Circular " "dependencies between components will not be sorted out.", ) return sorted_components + circular_components @@ -1493,9 +1469,7 @@ def parse_components(self, targets): break # If variable is not intended out - if not ( - var_info.get("public_interface") == "out" or var_info.get("private_interface") == "out" - ): + if not (var_info.get("public_interface") == "out" or var_info.get("private_interface") == "out"): continue # Check if the oldname is used in any components diff --git a/gotran/input/mathml.py b/gotran/input/mathml.py index eb01c85..bf94397 100644 --- a/gotran/input/mathml.py +++ b/gotran/input/mathml.py @@ -173,21 +173,13 @@ def _parse_subtree(self, root, parent=None, first_operand=True): use_parent = True eq += [self._operators[op]] - eq += ( - ["("] * use_parent - + self._parse_subtree(root[0], op) - + [")"] * use_parent - ) + eq += ["("] * use_parent + self._parse_subtree(root[0], op) + [")"] * use_parent return eq else: # Binary operator eq += ["("] * use_parent + self._parse_subtree(root[0], op) for operand in root[1:]: - eq = ( - eq - + [self._operators[op]] - + self._parse_subtree(operand, op, first_operand=False) - ) + eq = eq + [self._operators[op]] + self._parse_subtree(operand, op, first_operand=False) eq = eq + [")"] * use_parent return eq @@ -237,11 +229,7 @@ def _parse_eq(self, operands, parent): return self._parse_conditional("Eq", operands, "eq") # Parsing assignment - return ( - self._parse_subtree(operands[0], "eq") - + [self["eq"]] - + self._parse_subtree(operands[1], "eq") - ) + return self._parse_subtree(operands[0], "eq") + [self["eq"]] + self._parse_subtree(operands[1], "eq") def _parse_pi(self, var, parent): return ["pi"] @@ -318,26 +306,12 @@ def _parse_piecewise(self, cases, parent=None): piece_children = list(cases[0]) cond = self._parse_subtree(piece_children[1], "piecewise") true = self._parse_subtree(piece_children[0]) - return ( - ["Conditional", "("] - + cond - + [", "] - + true - + [", "] - + self._parse_piecewise(cases[1:]) - + [")"] - ) + return ["Conditional", "("] + cond + [", "] + true + [", "] + self._parse_piecewise(cases[1:]) + [")"] class MathMLCPPParser(MathMLBaseParser): def _parse_power(self, operands): - return ( - ["pow", "("] - + self._parse_subtree(operands[0]) - + [", "] - + self._parse_subtree(operands[1]) - + [")"] - ) + return ["pow", "("] + self._parse_subtree(operands[0]) + [", "] + self._parse_subtree(operands[1]) + [")"] def _parse_piecewise(self, cases): if len(cases) == 2: diff --git a/gotran/model/expressions.py b/gotran/model/expressions.py index 8c1548e..c7671a7 100644 --- a/gotran/model/expressions.py +++ b/gotran/model/expressions.py @@ -59,8 +59,7 @@ def recreate_expression(expr, *replace_dicts, **kwargs): replace_type = kwargs.get("replace_type", "xreplace") if replace_type not in ["xreplace", "subs"]: error( - "Valid alternatives for replace_type is: 'xreplace', " - "'subs' got {0}".format(replace_type), + "Valid alternatives for replace_type is: 'xreplace', " "'subs' got {0}".format(replace_type), ) # First do the replacements @@ -151,9 +150,7 @@ def __init__(self, name, expr, dependent=None): # Deal with Subs in sympy expression for sub_expr in expr.atoms(sp.Subs): # deal with one Subs at a time - subs = dict( - (key, value) for key, value in zip(sub_expr.variables, sub_expr.point) - ) + subs = dict((key, value) for key, value in zip(sub_expr.variables, sub_expr.point)) expr = expr.subs(sub_expr, sub_expr.expr.xreplace(subs)) @@ -335,8 +332,7 @@ def __init__(self, der_expr, dep_var, expr, dependent=None): # Check that the der_expr is dependent on var if dep_var.sym not in der_expr.sym.args: error( - "Cannot create a DerivativeExpression as {0} is not " - "dependent on {1}".format(der_expr, dep_var), + "Cannot create a DerivativeExpression as {0} is not " "dependent on {1}".format(der_expr, dep_var), ) der_sym = sp.Derivative(der_expr.sym, dep_var.sym) diff --git a/gotran/model/loadmodel.py b/gotran/model/loadmodel.py index 398fae7..1863e96 100644 --- a/gotran/model/loadmodel.py +++ b/gotran/model/loadmodel.py @@ -456,8 +456,9 @@ def model_arguments(**kwargs): for key, value in list(kwargs.items()): if not isinstance(value, (float, int, str, Param)): error( - "expected only 'float', 'int', 'str' or 'Param', as model_arguments, " - "got: '{}' for '{}'".format(type(value).__name__, key), + "expected only 'float', 'int', 'str' or 'Param', as model_arguments, " "got: '{}' for '{}'".format( + type(value).__name__, key + ), ) if key not in load_arguments: diff --git a/gotran/model/ode.py b/gotran/model/ode.py index 235a912..1c3cf74 100644 --- a/gotran/model/ode.py +++ b/gotran/model/ode.py @@ -343,14 +343,13 @@ def add_comp_and_children(added, comp): if der_expr is None: if prefix: error( - "Could not find expression: " - "({0}){1} while adding " - "derivative".format(prefix, obj.der_expr), + "Could not find expression: " "({0}){1} while adding " "derivative".format( + prefix, obj.der_expr + ), ) else: error( - "Could not find expression: " - "{0} while adding derivative".format( + "Could not find expression: " "{0} while adding derivative".format( obj.der_expr, ), ) @@ -369,14 +368,13 @@ def add_comp_and_children(added, comp): if dep_var is None: if prefix: error( - "Could not find expression: " - "({0}){1} while adding " - "derivative".format(prefix, obj.dep_var), + "Could not find expression: " "({0}){1} while adding " "derivative".format( + prefix, obj.dep_var + ), ) else: error( - "Could not find expression: " - "{0} while adding derivative".format( + "Could not find expression: " "{0} while adding derivative".format( obj.dep_var, ), ) @@ -436,14 +434,10 @@ def save(self, basename=None): comp_names[comp] = comp_name states = [ - f"{obj.name}={obj.param.repr(include_name=False)}," - for obj in comp.ode_objects - if isinstance(obj, State) + f"{obj.name}={obj.param.repr(include_name=False)}," for obj in comp.ode_objects if isinstance(obj, State) ] parameters = [ - f"{obj.name}={obj.param.repr(include_name=False)}," - for obj in comp.ode_objects - if isinstance(obj, Parameter) + f"{obj.name}={obj.param.repr(include_name=False)}," for obj in comp.ode_objects if isinstance(obj, Parameter) ] if states: lines.append("") @@ -501,9 +495,7 @@ def save(self, basename=None): # If comment is component comment if str(obj) == comp_comment: lines.append("") - comp_name = ( - comp_names[comp] if comp_names[comp] else f'"{basename}"' - ) + comp_name = comp_names[comp] if comp_names[comp] else f'"{basename}"' lines.append(f"expressions({comp_name})") # Just add the comment @@ -547,10 +539,7 @@ def register_ode_object(self, obj, comp, dependent=None): # If duplicated object is an ODE Parameter and the added object is # either a State or a Parameter we replace the Parameter. elif ( - isinstance(dup_obj, Parameter) - and dup_comp == self - and comp != self - and isinstance(obj, (State, Parameter)) + isinstance(dup_obj, Parameter) and dup_comp == self and comp != self and isinstance(obj, (State, Parameter)) ): timer = Timer("Replace objects") # noqa: F841 @@ -590,14 +579,9 @@ def register_ode_object(self, obj, comp, dependent=None): # If duplicated object is an ODE Parameter and the added # object is an Intermediate we raise an error. - elif ( - isinstance(dup_obj, Parameter) - and dup_comp == self - and isinstance(obj, Expression) - ): + elif isinstance(dup_obj, Parameter) and dup_comp == self and isinstance(obj, Expression): error( - "Cannot replace an ODE parameter with an Expression, " - "only with Parameters and States.", + "Cannot replace an ODE parameter with an Expression, " "only with Parameters and States.", ) # If State, Parameter or DerivativeExpression we always raise an error @@ -617,8 +601,7 @@ def register_ode_object(self, obj, comp, dependent=None): for oo in [dup_obj, obj] ): error( - "Cannot register {0}. A {1} with name '{2}' is " - "already registered in this ODE.".format( + "Cannot register {0}. A {1} with name '{2}' is " "already registered in this ODE.".format( type(obj).__name__, type(dup_obj).__name__, dup_obj.name, @@ -680,8 +663,7 @@ def register_ode_object(self, obj, comp, dependent=None): if dep_obj is None: error( - "The symbol '{0}' is not declared within the '{1}' " - "ODE.".format(sym, self.name), + "The symbol '{0}' is not declared within the '{1}' " "ODE.".format(sym, self.name), ) # Store object dependencies @@ -693,8 +675,7 @@ def register_ode_object(self, obj, comp, dependent=None): if isinstance(obj, StateSolution) and self.object_used_in.get(obj.state): used_in = self.object_used_in.get(obj.state) error( - "A state solution cannot have been used in " - "any previous expressions. {0} is used in: {1}".format( + "A state solution cannot have been used in " "any previous expressions. {0} is used in: {1}".format( obj.state, used_in, ), @@ -891,9 +872,7 @@ def mass_matrix(self): self._mass_matrix = sp.Matrix( N, N, - lambda i, j: ( - 1 if i == j and isinstance(state_exprs[i], StateDerivative) else 0 - ), + lambda i, j: (1 if i == j and isinstance(state_exprs[i], StateDerivative) else 0), ) return self._mass_matrix @@ -906,9 +885,7 @@ def is_dae(self): if not self.is_complete: error("The ODE is not complete") - return any( - isinstance(expr, AlgebraicExpression) for expr in self.state_expressions - ) + return any(isinstance(expr, AlgebraicExpression) for expr in self.state_expressions) def finalize(self): """ @@ -984,9 +961,7 @@ def _replace_object(self, old_obj, replaced_obj, replace_dicts): # FIXME: Do not remove the dependencies # self.expression_dependencies[updated_expr] = \ # self.expression_dependencies.pop(expr) - self.expression_dependencies[replaced_expr] = self.expression_dependencies[ - old_expr - ] + self.expression_dependencies[replaced_expr] = self.expression_dependencies[old_expr] # Find the index of old expression and exchange it with updated old_comp = self.object_component[old_expr] @@ -1020,9 +995,7 @@ def _handle_expr_component(self, comp, expr): # We are shifting expression components elif self.all_expr_components_ordered[-1] != comp.name: # Finalize the last component we visited - self.all_components[ - self.all_expr_components_ordered[-1] - ].finalize_component() + self.all_components[self.all_expr_components_ordered[-1]].finalize_component() # Append this component self.all_expr_components_ordered.append(comp.name) @@ -1059,14 +1032,16 @@ def _expand_single_derivative(self, comp, obj, der_expr, replace_dict, dependent if not isinstance(der_expr.args[0], AppliedUndef): error( - "Can only register Derivatives of allready registered " - "Expressions. Got: {0}".format(sympycode(der_expr.args[0])), + "Can only register Derivatives of allready registered " "Expressions. Got: {0}".format( + sympycode(der_expr.args[0]) + ), ) if not isinstance(der_expr.args[1], (AppliedUndef, sp.Symbol)): error( - "Can only register Derivatives with a single dependent " - "variabe. Got: {0}".format(sympycode(der_expr.args[1])), + "Can only register Derivatives with a single dependent " "variabe. Got: {0}".format( + sympycode(der_expr.args[1]) + ), ) # Get the expr and dependent variable objects @@ -1093,8 +1068,7 @@ def _expand_single_derivative(self, comp, obj, der_expr, replace_dict, dependent if not isinstance(expr_obj, Expression): error( - "Can only differentiate expressions or states. Got {0} as " - "the derivative expression.".format(expr_obj), + "Can only differentiate expressions or states. Got {0} as " "the derivative expression.".format(expr_obj), ) # Expand derivative and see if it is trivial @@ -1107,10 +1081,7 @@ def _expand_single_derivative(self, comp, obj, der_expr, replace_dict, dependent or ( isinstance(der_result, (sp.Mul, sp.Pow, sp.Add)) and len(der_result.args) == 2 - and all( - isinstance(arg, (sp.Number, sp.Symbol, AppliedUndef)) - for arg in der_result.args - ) + and all(isinstance(arg, (sp.Number, sp.Symbol, AppliedUndef)) for arg in der_result.args) ) ): replace_dict[der_expr] = der_result diff --git a/gotran/model/odecomponent.py b/gotran/model/odecomponent.py index 3abd783..1d017b1 100644 --- a/gotran/model/odecomponent.py +++ b/gotran/model/odecomponent.py @@ -217,8 +217,7 @@ def add_states(self, *args, **kwargs): for arg in states: if not isinstance(arg, tuple) or len(arg) != 2: error( - "excpected tuple with lenght 2 with state name (str) " - "and init values as the args argument.", + "excpected tuple with lenght 2 with state name (str) " "and init values as the args argument.", ) state_name, init = arg @@ -266,8 +265,7 @@ def add_parameters(self, *args, **kwargs): for arg in params: if not isinstance(arg, tuple) or len(arg) != 2: error( - "excpected tuple with lenght 2 with parameter name (str) " - "and init values as the args argument.", + "excpected tuple with lenght 2 with parameter name (str) " "and init values as the args argument.", ) parameter_name, value = arg @@ -360,14 +358,12 @@ def add_state_solution(self, state, expr, dependent=None): if f"d{state.name}_dt" in self.ode_objects: error( - "Cannot registered a state solution for a state " - "that has a state derivative registered.", + "Cannot registered a state solution for a state " "that has a state derivative registered.", ) if f"alg_{state.name}_0" in self.ode_objects: error( - "Cannot registered a state solution for a state " - "that has an algebraic expression registered.", + "Cannot registered a state solution for a state " "that has an algebraic expression registered.", ) # Create a StateSolution in the present component @@ -463,14 +459,12 @@ def add_algebraic(self, state, expr, dependent=None): if f"d{state.name}_dt" in self.ode_objects: error( - "Cannot registered an algebraic expression for a state " - "that has a state derivative registered.", + "Cannot registered an algebraic expression for a state " "that has a state derivative registered.", ) if state.is_solved: error( - "Cannot registered an algebraic expression for a state " - "which is registered solved.", + "Cannot registered an algebraic expression for a state " "which is registered solved.", ) # Create an AlgebraicExpression in the present component @@ -491,9 +485,7 @@ def full_states(self): Return a list of all states in the component and its children that are not solved and determined by a state expression """ - return [ - expr.state for expr in self.state_expressions if not expr.state.is_solved - ] + return [expr.state for expr in self.state_expressions if not expr.state.is_solved] @property def full_state_vector(self): @@ -523,9 +515,7 @@ def state_expressions(self): """ Return a list of state expressions """ - return sorted( - (obj for obj in iter_objects(self, False, False, False, StateExpression)) - ) + return sorted((obj for obj in iter_objects(self, False, False, False, StateExpression))) @property def rate_expressions(self): @@ -614,9 +604,7 @@ def is_complete(self): """ True if the component and all its children are locally complete """ - return self.is_locally_complete and all( - child.is_complete for child in list(self.children.values()) - ) + return self.is_locally_complete and all(child.is_complete for child in list(self.children.values())) @property def is_locally_complete(self): @@ -624,11 +612,7 @@ def is_locally_complete(self): True if the number of non-solved states are the same as the number of registered state expressions """ - num_local_states = sum( - 1 - for obj in self.ode_objects - if isinstance(obj, State) and not obj.is_solved - ) + num_local_states = sum(1 for obj in self.ode_objects if isinstance(obj, State) and not obj.is_solved) return num_local_states == len(self._local_state_expressions) @@ -651,8 +635,7 @@ def __setattr__(self, name, value): isinstance(value, sp.Basic) and symbols_from_expr(value) ): debug( - "Not registering: {0} as attribut. It does not contain " - "any symbols or scalars.".format(name), + "Not registering: {0} as attribut. It does not contain " "any symbols or scalars.".format(name), ) # FIXME: Should we raise an error? @@ -730,18 +713,14 @@ def _expect_state(self, state, allow_state_solution=False, only_local_states=Fal if state is None: error(f"{name} is not registered in this ODE") - if only_local_states and not ( - state in self.states - or (state in self.intermediates and allow_state_solution) - ): + if only_local_states and not (state in self.states or (state in self.intermediates and allow_state_solution)): error(f"{name} is not registered in component {self.name}") check_arg(state, allowed, 0) if isinstance(state, State) and state.is_solved: error( - "Cannot registered a state expression for a state " - "which is registered solved.", + "Cannot registered a state expression for a state " "which is registered solved.", ) return state @@ -753,8 +732,7 @@ def _register_component_object(self, obj, dependent=None): if self._is_finalized: error( - "Cannot add {0} {1} to component {2} it is " - "already finalized.".format(obj.__class__.__name__, obj, self), + "Cannot add {0} {1} to component {2} it is " "already finalized.".format(obj.__class__.__name__, obj, self), ) self._check_reserved_wordings(obj) @@ -774,16 +752,16 @@ def _register_component_object(self, obj, dependent=None): if obj.state in self._local_state_expressions: error( - "A StateExpression for state {0} is already registered " - "in this component.".format(obj.state.name), + "A StateExpression for state {0} is already registered " "in this component.".format(obj.state.name), ) # Check that the state is registered in this component state_obj = self.ode_objects.get(obj.state.name) if not isinstance(state_obj, State): error( - "The state expression {0} defines state {1}, which is " - "not registered in the {2} component.".format(obj, obj.state, self), + "The state expression {0} defines state {1}, which is " "not registered in the {2} component.".format( + obj, obj.state, self + ), ) self._local_state_expressions[obj.state] = obj @@ -812,8 +790,9 @@ def _register_component_object(self, obj, dependent=None): def _check_reserved_wordings(self, obj): if obj.name in _all_keywords: error( - "Cannot register a {0} with a computer language " - "keyword name: {1}".format(obj.__class__.__name__, obj.name), + "Cannot register a {0} with a computer language " "keyword name: {1}".format( + obj.__class__.__name__, obj.name + ), ) # Check for reserved Expression wordings @@ -853,9 +832,7 @@ def _add_rates(self, states, rate_matrix): states = (states, states) # else tuple - elif len(states) != 2 and not all( - isinstance(list_of_states, list) for list_of_states in states - ): + elif len(states) != 2 and not all(isinstance(list_of_states, list) for list_of_states in states): error("expected a tuple of 2 lists with states as the " "states argument") # Check index arguments @@ -877,9 +854,7 @@ def _add_rates(self, states, rate_matrix): value = rate_matrix[i, j] # If 0 as rate - if (isinstance(value, scalars) and value == 0) or ( - isinstance(value, sp.Basic) and value.is_zero - ): + if (isinstance(value, scalars) and value == 0) or (isinstance(value, sp.Basic) and value.is_zero): continue if state_i == state_j: @@ -1013,8 +988,7 @@ def _finalize_markov_model(self): for (ind_from, ind_to), times in list(rate_check.items()): if times != 2: error( - "Only one rate between the states {0} and {1} was " - "registered, expected two.".format( + "Only one rate between the states {0} and {1} was " "registered, expected two.".format( states[ind_from], states[ind_to], ), diff --git a/gotran/model/odeobjects.py b/gotran/model/odeobjects.py index 1d08537..f0b20f2 100644 --- a/gotran/model/odeobjects.py +++ b/gotran/model/odeobjects.py @@ -84,10 +84,7 @@ def __init__(self, name, dependent=None): ODEObject.__dependent_counts[dependent._count] += 1 # FIXME: Do not hardcode the fractional increase - self._count = ( - dependent._count - + ODEObject.__dependent_counts[dependent._count] * 0.00001 - ) + self._count = dependent._count + ODEObject.__dependent_counts[dependent._count] * 0.00001 def __hash__(self): return id(self) @@ -182,10 +179,7 @@ def _recount(self, new_count=None, dependent=None): ODEObject.__dependent_counts[dependent._count] += 1 # FIXME: Do not hardcode the fractional increase - self._count = ( - dependent._count - + ODEObject.__dependent_counts[dependent._count] * 0.00001 - ) + self._count = dependent._count + ODEObject.__dependent_counts[dependent._count] * 0.00001 else: self._count = ODEObject.__count ODEObject.__count += 1 @@ -549,9 +543,7 @@ def __init__( if flatten and len(indices) > 1: indices = ( sum( - reduce(lambda i, j: i * j, shape[i + 1 :], 1) - * (index + index_offset) - for i, index in enumerate(indices) + reduce(lambda i, j: i * j, shape[i + 1 :], 1) * (index + index_offset) for i, index in enumerate(indices) ), ) else: diff --git a/gotran/model/utils.py b/gotran/model/utils.py index 4668203..9c40fa0 100644 --- a/gotran/model/utils.py +++ b/gotran/model/utils.py @@ -65,11 +65,7 @@ def ode_primitives(expr, time): for node in pt: # Collect AppliedUndefs which are functions of time - if ( - isinstance(node, AppliedUndef) - and len(node.args) == 1 - and node.args[0] == time - ): + if isinstance(node, AppliedUndef) and len(node.args) == 1 and node.args[0] == time: pt.skip() symbols.add(node) elif isinstance(node, Symbol): @@ -357,8 +353,7 @@ def __setitem__(self, states, expr): else: if not isinstance(states, tuple) or len(states) != 2: error( - "Expected a tuple of size 2 with states when " - "registering a single rate.", + "Expected a tuple of size 2 with states when " "registering a single rate.", ) # NOTE: the actuall item is set by the component while calling this diff --git a/gotran/scripts/cellml2gotran.py b/gotran/scripts/cellml2gotran.py index 8963541..c0e75d8 100644 --- a/gotran/scripts/cellml2gotran.py +++ b/gotran/scripts/cellml2gotran.py @@ -71,9 +71,7 @@ def list_parser(option, opt_str, value, parser, arg_list): parser.add_option( "-p", "--strip_parent_name", - help="If True strip the name from " - "the child component it contains the name of the " - "parent component.", + help="If True strip the name from " "the child component it contains the name of the " "parent component.", action="store_true", default=True, dest="strip_parent_name", diff --git a/gotran/scripts/gotran2c.py b/gotran/scripts/gotran2c.py index 21ffa91..47f6a0d 100755 --- a/gotran/scripts/gotran2c.py +++ b/gotran/scripts/gotran2c.py @@ -61,21 +61,14 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), system_headers=Param( True, - description="If true system " - "headers needed to compile moudle is " - "included.", + description="If true system " "headers needed to compile moudle is " "included.", ), output=Param("", description="Specify output file name"), - **dict( - (name, param) - for name, param in list(generation_params.items()) - if name not in ["class_code"] - ), + **dict((name, param) for name, param in list(generation_params.items()) if name not in ["class_code"]), ) params.parse_args(usage="usage: %prog FILE [options]") # sys.argv[2:]) diff --git a/gotran/scripts/gotran2cpp.py b/gotran/scripts/gotran2cpp.py index 028d88d..88a9f84 100644 --- a/gotran/scripts/gotran2cpp.py +++ b/gotran/scripts/gotran2cpp.py @@ -62,15 +62,12 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), output=Param("", description="Specify output file name"), system_headers=Param( True, - description="If true system " - "headers needed to compile moudle is " - "included.", + description="If true system " "headers needed to compile moudle is " "included.", ), **generation_params, ) diff --git a/gotran/scripts/gotran2cuda.py b/gotran/scripts/gotran2cuda.py index 1eb7165..a4d7d08 100644 --- a/gotran/scripts/gotran2cuda.py +++ b/gotran/scripts/gotran2cuda.py @@ -51,21 +51,14 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), system_headers=Param( True, - description="If true system " - "headers needed to compile moudle is " - "included.", + description="If true system " "headers needed to compile moudle is " "included.", ), output=Param("", description="Specify output file name"), - **dict( - (name, param) - for name, param in list(generation_params.items()) - if name not in ["class_code"] - ), + **dict((name, param) for name, param in list(generation_params.items()) if name not in ["class_code"]), ) params.parse_args(usage="usage: %prog FILE [options]") # sys.argv[2:]) diff --git a/gotran/scripts/gotran2dolfin.py b/gotran/scripts/gotran2dolfin.py index 80e2125..be30c8b 100644 --- a/gotran/scripts/gotran2dolfin.py +++ b/gotran/scripts/gotran2dolfin.py @@ -33,8 +33,7 @@ def gotran2dolfin(filename, params): f.write(code_gen.init_states_code(ode) + "\n\n") f.write(code_gen.init_parameters_code(ode) + "\n\n") f.write( - code_gen.function_code(rhs_expressions(ode, params=code_gen.params.code)) - + "\n", + code_gen.function_code(rhs_expressions(ode, params=code_gen.params.code)) + "\n", ) diff --git a/gotran/scripts/gotran2julia.py b/gotran/scripts/gotran2julia.py index b4369a1..7272de0 100644 --- a/gotran/scripts/gotran2julia.py +++ b/gotran/scripts/gotran2julia.py @@ -55,8 +55,7 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), output=Param("", description="Specify output file name"), **JuliaCodeGenerator.default_parameters(), diff --git a/gotran/scripts/gotran2latex.py b/gotran/scripts/gotran2latex.py index 6b04a1e..6bab1f2 100755 --- a/gotran/scripts/gotran2latex.py +++ b/gotran/scripts/gotran2latex.py @@ -30,9 +30,7 @@ def gotran2latex(filename, params): if not ode.is_complete: raise Exception("Incomplete ODE") - params.output = ( - params.output or Path(filename).with_name(f"{ode.name}.tex").as_posix() - ) + params.output = params.output or Path(filename).with_name(f"{ode.name}.tex").as_posix() # Create a gotran -> LaTeX document generator gen = LatexCodeGenerator(ode, params) @@ -49,8 +47,7 @@ def main(): params = ParameterDict( sympy_contraction=Param( True, - description="If True sympy contraction" - " will be used, turning (V-3)/2 into V/2-3/2", + description="If True sympy contraction" " will be used, turning (V-3)/2 into V/2-3/2", ), **params, ) diff --git a/gotran/scripts/gotran2matlab.py b/gotran/scripts/gotran2matlab.py index 06a1e62..c0c6521 100755 --- a/gotran/scripts/gotran2matlab.py +++ b/gotran/scripts/gotran2matlab.py @@ -44,8 +44,7 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), output=Param("", description="Specify output file name"), **generation_params, diff --git a/gotran/scripts/gotran2md.py b/gotran/scripts/gotran2md.py index 9351463..8b75fa3 100644 --- a/gotran/scripts/gotran2md.py +++ b/gotran/scripts/gotran2md.py @@ -51,17 +51,11 @@ def gotran2md(filename): ) expr = "\n\n".join( - [ - f"$$\n{p._repr_latex_name()} = {p._repr_latex_expr()}\n$$" - for p in ode.intermediates - ], + [f"$$\n{p._repr_latex_name()} = {p._repr_latex_expr()}\n$$" for p in ode.intermediates], ) state_expr = "\n\n".join( - [ - f"$$\n{p._repr_latex_name()} = {p._repr_latex_expr()}\n$$" - for p in ode.state_expressions - ], + [f"$$\n{p._repr_latex_name()} = {p._repr_latex_expr()}\n$$" for p in ode.state_expressions], ) mdname = filename.with_suffix(".md") diff --git a/gotran/scripts/gotran2opencl.py b/gotran/scripts/gotran2opencl.py index d95e32c..29d65c5 100644 --- a/gotran/scripts/gotran2opencl.py +++ b/gotran/scripts/gotran2opencl.py @@ -47,21 +47,14 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), system_headers=Param( True, - description="If true system " - "headers needed to compile moudle is " - "included.", + description="If true system " "headers needed to compile moudle is " "included.", ), output=Param("", description="Specify output file name"), - **dict( - (name, param) - for name, param in list(generation_params.items()) - if name not in ["class_code"] - ), + **dict((name, param) for name, param in list(generation_params.items()) if name not in ["class_code"]), ) params.parse_args(usage="usage: %prog FILE [options]") # sys.argv[2:]) @@ -73,8 +66,7 @@ def main(): file_name = sys.argv[1] info( - "Note: The OpenCL support in gotran is a work in progress. " - "The CUDA generator is recommended for NVIDIA GPUs.", + "Note: The OpenCL support in gotran is a work in progress. " "The CUDA generator is recommended for NVIDIA GPUs.", ) gotran2opencl(file_name, params) diff --git a/gotran/scripts/gotran2py.py b/gotran/scripts/gotran2py.py index 7b7d4cd..6a26ddb 100644 --- a/gotran/scripts/gotran2py.py +++ b/gotran/scripts/gotran2py.py @@ -66,8 +66,7 @@ def main(): params = ParameterDict( list_timings=Param( False, - description="If true timings for reading " - "and evaluating the model is listed.", + description="If true timings for reading " "and evaluating the model is listed.", ), output=Param("", description="Specify output file name"), import_inside_functions=Param( diff --git a/gotran/scripts/gotranprobe.py b/gotran/scripts/gotranprobe.py index 864bc12..cc5a41a 100644 --- a/gotran/scripts/gotranprobe.py +++ b/gotran/scripts/gotranprobe.py @@ -74,8 +74,7 @@ def gotranprobe(filename, params): dep.name for interm in comp.intermediates for dep in ode.expression_dependencies[interm] - if not isinstance(dep, Comment) - and dep not in comp.full_states + comp.parameters + if not isinstance(dep, Comment) and dep not in comp.full_states + comp.parameters ), ), ) diff --git a/gotran/scripts/gotranrun.py b/gotran/scripts/gotranrun.py index aa1ac7b..446062d 100644 --- a/gotran/scripts/gotranrun.py +++ b/gotran/scripts/gotranrun.py @@ -72,8 +72,7 @@ def gotranrun(filename, params): arguments = dict() for arg_name, arg_value in [ - (model_arguments[i * 2], model_arguments[i * 2 + 1]) - for i in range(int(len(model_arguments) / 2)) + (model_arguments[i * 2], model_arguments[i * 2 + 1]) for i in range(int(len(model_arguments) / 2)) ]: arguments[arg_name] = arg_value @@ -82,8 +81,7 @@ def gotranrun(filename, params): # Check for DAE if ode.is_dae: error( - "Can only integrate pure ODEs. {0} includes algebraic states " - "and is hence a DAE.".format(ode.name), + "Can only integrate pure ODEs. {0} includes algebraic states " "and is hence a DAE.".format(ode.name), ) # Get monitored and plot states @@ -93,11 +91,7 @@ def gotranrun(filename, params): x_name = params.plot_x state_names = [state.name for state in ode.full_states] - monitored_plot = [ - plot_states.pop(plot_states.index(name)) - for name in plot_states[:] - if name not in state_names - ] + monitored_plot = [plot_states.pop(plot_states.index(name)) for name in plot_states[:] if name not in state_names] monitored = [] all_monitored_names = [] @@ -114,8 +108,7 @@ def gotranrun(filename, params): # Check x_name if x_name not in ["time"] + monitored + state_names: error( - "Expected plot_x to be either 'time' or one of the plotable " - "variables, got {}".format(x_name), + "Expected plot_x to be either 'time' or one of the plotable " "variables, got {}".format(x_name), ) # Logic if x_name is not 'time' as we then need to add the name to @@ -145,15 +138,13 @@ def gotranrun(filename, params): user_params = dict() for param_name, param_value in [ - (parameter_values[i * 2], parameter_values[i * 2 + 1]) - for i in range(int(len(parameter_values) / 2)) + (parameter_values[i * 2], parameter_values[i * 2 + 1]) for i in range(int(len(parameter_values) / 2)) ]: user_params[param_name] = float(param_value) user_ic = dict() for state_name, state_value in [ - (init_conditions[i * 2], init_conditions[i * 2 + 1]) - for i in range(int(len(init_conditions) / 2)) + (init_conditions[i * 2], init_conditions[i * 2 + 1]) for i in range(int(len(init_conditions) / 2)) ]: user_ic[state_name] = float(state_value) @@ -182,10 +173,7 @@ def gotranrun(filename, params): y0 = result.x print( "Found stead state:", - ", ".join( - f"{state.name}: {value:e}" - for value, state in zip(y0, ode.full_states) - ), + ", ".join(f"{state.name}: {value:e}" for value, state in zip(y0, ode.full_states)), ) else: warning(result.message) @@ -249,9 +237,7 @@ def gotranrun(filename, params): if params.save_results: save_results[ind, 0] = time save_results[ind, 1 : len(state_names) + 1] = res - save_results[ind, len(state_names) + 1 :] = monitored_get_values[ - all_monitor_inds - ] + save_results[ind, len(state_names) + 1 :] = monitored_get_values[all_monitor_inds] # Save data if params.save_results: @@ -290,11 +276,7 @@ def gotranrun(filename, params): # line_styles = cycle([c+s for s in ["-", "--", "-.", ":"] # for c in plt.rcParams["axes.color_cycle"]]) line_styles = cycle( - [ - c + s - for s in ["-", "--", "-.", ":"] - for c in ["b", "g", "r", "c", "m", "y", "k"] - ], + [c + s for s in ["-", "--", "-.", ":"] for c in ["b", "g", "r", "c", "m", "y", "k"]], ) plotted_items = 0 @@ -335,8 +317,7 @@ def main(): steady_state = ParameterDict( solve=Param( False, - description="If true scipy.optimize.root is used " - "to find a steady state for a given parameters.", + description="If true scipy.optimize.root is used " "to find a steady state for a given parameters.", ), method=OptionParam( "hybr", @@ -371,8 +352,7 @@ def main(): plot_y=Param(["V"], description="States or monitored to plot on the y axis."), plot_x=Param( "time", - description="Values used for the x axis. Can be time " - "and any valid plot_y variable.", + description="Values used for the x axis. Can be time " "and any valid plot_y variable.", ), model_arguments=Param([""], description="Set model arguments of the model"), code=code_params, @@ -382,8 +362,7 @@ def main(): ), basename=Param( "results", - description="The basename of the results " - "file if the 'save_results' options is True.", + description="The basename of the results " "file if the 'save_results' options is True.", ), ) diff --git a/pyproject.toml b/pyproject.toml index a8d5877..2d96deb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,7 @@ exclude = [ ] # Same as Black. -line-length = 110 +line-length = 125 # Assume Python 3.10. target-version = "py310" diff --git a/sandbox/cellml/test_cellml.py b/sandbox/cellml/test_cellml.py index 068aa4c..8d1ffe1 100644 --- a/sandbox/cellml/test_cellml.py +++ b/sandbox/cellml/test_cellml.py @@ -64,9 +64,7 @@ def get_encapsulation(elements): for group in cellml.getiterator(cellml_namespace + "group"): children = group.getchildren() - if ( - children and children[0].attrib.get("relationship") == "containment" - ): # encapsulation + if children and children[0].attrib.get("relationship") == "containment": # encapsulation encapsulations = get_encapsulation(children[1:]) si_unit_map = { @@ -219,13 +217,9 @@ def parse_component(comp): if state_variable is not None and state_variable not in state_variables: state_variables[state_variable] = derivative - state_variables = OrderedDict( - (state, variables.pop(state, None)) for state in state_variables - ) + state_variables = OrderedDict((state, variables.pop(state, None)) for state in state_variables) - parameters = OrderedDict( - (name, value) for name, value in list(variables.items()) if value is not None - ) + parameters = OrderedDict((name, value) for name, value in list(variables.items()) if value is not None) component["states"] = state_variables component["parameters"] = parameters diff --git a/sandbox/gpu/cossplotter.py b/sandbox/gpu/cossplotter.py index e2a9755..9bf92a9 100644 --- a/sandbox/gpu/cossplotter.py +++ b/sandbox/gpu/cossplotter.py @@ -73,10 +73,7 @@ def plotResults(_file, plotTypes=None, get_stored_fstates=False): ) figures = list() for pType in plotTypes: - if ( - pType["x"]["type"] not in validPlotTypes - or pType["y"]["type"] not in validPlotTypes - ): + if pType["x"]["type"] not in validPlotTypes or pType["y"]["type"] not in validPlotTypes: print(str(pType) + " plot not yet implemented") else: figures.append(pType) diff --git a/sandbox/gpu/cosstester.py b/sandbox/gpu/cosstester.py index 9e5950a..185a596 100644 --- a/sandbox/gpu/cosstester.py +++ b/sandbox/gpu/cosstester.py @@ -79,11 +79,7 @@ def __init__(self, **kwargs): self.num_nodes, self.stored_field_states, ) - initial_field_params = [ - param.init - for param in self.ode.parameters - if param.name in self.field_parameters - ] + initial_field_params = [param.init for param in self.ode.parameters if param.name in self.field_parameters] if self.field_parameter_values_getter_fn is None: self.field_parameter_values = None else: @@ -127,9 +123,7 @@ def required_parameters(): @staticmethod def parameter_iterability_map(): - return { - k: v == [""] for k, v in list(COSSTestCase.default_parameters().items()) - } + return {k: v == [""] for k, v in list(COSSTestCase.default_parameters().items())} @staticmethod def check_kwargs(**kwargs): @@ -907,8 +901,7 @@ def saveResults(title, results, directory): f.write( "{0}\n".format( "\n".join( - " ".join("{0}".format(fs) for fs in field_states) - for field_states in tc.stored_field_states + " ".join("{0}".format(fs) for fs in field_states) for field_states in tc.stored_field_states ), ), ) @@ -1028,10 +1021,7 @@ def getDataFromFile(_file, get_stored_fstates=False): list(map(float, fpv)) for fpv in map( ast.literal_eval, - [ - "[" + s.replace("]", " ").replace("[", " ") + "]" - for s in f.readline().rstrip().split("] [") - ], + ["[" + s.replace("]", " ").replace("[", " ") + "]" for s in f.readline().rstrip().split("] [")], ) ] datum["field_state_values"] = list( @@ -1127,10 +1117,7 @@ def plotResults(_file, plotTypes=None, get_stored_fstates=False): ) figures = list() for pType in plotTypes: - if ( - pType["x"]["type"] not in validPlotTypes - or pType["y"]["type"] not in validPlotTypes - ): + if pType["x"]["type"] not in validPlotTypes or pType["y"]["type"] not in validPlotTypes: print(str(pType) + " plot not yet implemented") else: figures.append(pType) diff --git a/sandbox/gpu/cuda_vs_goss.py b/sandbox/gpu/cuda_vs_goss.py index 91efb59..3be0478 100644 --- a/sandbox/gpu/cuda_vs_goss.py +++ b/sandbox/gpu/cuda_vs_goss.py @@ -169,34 +169,22 @@ def run_gpu( ) print( "PERCENT REL DIFF SINGLE > 0.1%", - (np.absolute((goss_result - gpu_result_single) / goss_result) > 1.0e-3).sum() - * 1.0 - / len(goss_result) - * 100, + (np.absolute((goss_result - gpu_result_single) / goss_result) > 1.0e-3).sum() * 1.0 / len(goss_result) * 100, "%", ) print( "PERCENT REL DIFF SINGLE > 1%", - (np.absolute((goss_result - gpu_result_single) / goss_result) > 1.0e-2).sum() - * 1.0 - / len(goss_result) - * 100, + (np.absolute((goss_result - gpu_result_single) / goss_result) > 1.0e-2).sum() * 1.0 / len(goss_result) * 100, "%", ) print( "PERCENT REL DIFF SINGLE > 2%", - (np.absolute((goss_result - gpu_result_single) / goss_result) > 2.0e-2).sum() - * 1.0 - / len(goss_result) - * 100, + (np.absolute((goss_result - gpu_result_single) / goss_result) > 2.0e-2).sum() * 1.0 / len(goss_result) * 100, "%", ) print( "PERCENT REL DIFF SINGLE > 3%", - (np.absolute((goss_result - gpu_result_single) / goss_result) > 3.0e-2).sum() - * 1.0 - / len(goss_result) - * 100, + (np.absolute((goss_result - gpu_result_single) / goss_result) > 3.0e-2).sum() * 1.0 / len(goss_result) * 100, "%", ) print() diff --git a/sandbox/gpu/cudaodesystemsolver.py b/sandbox/gpu/cudaodesystemsolver.py index bd1be68..2b6a85f 100644 --- a/sandbox/gpu/cudaodesystemsolver.py +++ b/sandbox/gpu/cudaodesystemsolver.py @@ -65,9 +65,7 @@ def init_cuda(self, params=None): nvcc = self.params.nvcc or "nvcc" gpu_arch = self.params.gpu_arch if self.params.gpu_arch else None gpu_code = self.params.gpu_code if self.params.gpu_code else None - cuda_cache_dir = ( - self.params.cuda_cache_dir if self.params.cuda_cache_dir else None - ) + cuda_cache_dir = self.params.cuda_cache_dir if self.params.cuda_cache_dir else None nvcc_options = self.params.nvcc_options # FIXME: modelparameters needs a ListParam if nvcc_options is not None and len(nvcc_options) > 0 and nvcc_options[0] == "": @@ -87,9 +85,7 @@ def init_cuda(self, params=None): self.ctx.set_cache_config(cuda.func_cache.PREFER_L1) - float_t = ( - "float64" if self.params.code.float_precision == "double" else "float32" - ) + float_t = "float64" if self.params.code.float_precision == "double" else "float32" float_sz = np.dtype(float_t).itemsize # Allocate and initialise states @@ -244,8 +240,7 @@ def _get_block(self): def _get_grid(self): block_size = self.params.block_size grid = ( - self._num_nodes // block_size - + (0 if self._num_nodes % block_size == 0 else 1), + self._num_nodes // block_size + (0 if self._num_nodes % block_size == 0 else 1), 1, ) return grid @@ -256,13 +251,7 @@ def _get_code(self): def _dump_kernel_code(self): if not self.is_ready(): return "" - fname = ( - "tmp" - + os.path.sep - + "kernel-" - + hashlib.sha1(self._get_code()).hexdigest() - + ".cu" - ) + fname = "tmp" + os.path.sep + "kernel-" + hashlib.sha1(self._get_code()).hexdigest() + ".cu" with open(fname, "w") as f: f.write(self._get_code()) return fname @@ -299,11 +288,7 @@ def __init__(self, num_nodes, ode, init_field_parameters=None, params=None): self.get_field_states() # FIXME: modelparameters needs a ListParam - if ( - init_field_parameters is not None - and len(p_field_parameters) > 0 - and p_field_states[0] != "" - ): + if init_field_parameters is not None and len(p_field_parameters) > 0 and p_field_states[0] != "": self.set_field_parameters(init_field_parameters) self.ode_substeps = self.params.ode_substeps @@ -328,8 +313,7 @@ def default_parameters(): ode_substeps=ScalarParam( 1, ge=1, - description="Number of ODE steps to compute per " - "forward function call", + description="Number of ODE steps to compute per " "forward function call", ), nvcc=Param("nvcc", description="Command to run nvcc compiler"), gpu_arch=TypelessParam( diff --git a/sandbox/gpu/test_gpu.py b/sandbox/gpu/test_gpu.py index 7f446b7..6ac4682 100644 --- a/sandbox/gpu/test_gpu.py +++ b/sandbox/gpu/test_gpu.py @@ -62,14 +62,7 @@ .replace("parameters[", "parameters[param_offset+") ) -gpu_code = ( - '#include "math.h"\n\n' - + init_state_code - + "\n\n" - + init_param_code - + "\n\n" - + rhs_code -) +gpu_code = '#include "math.h"\n\n' + init_state_code + "\n\n" + init_param_code + "\n\n" + rhs_code print(gpu_code) diff --git a/sandbox/gpu/testeverything.py b/sandbox/gpu/testeverything.py index 3232c32..847699a 100644 --- a/sandbox/gpu/testeverything.py +++ b/sandbox/gpu/testeverything.py @@ -45,11 +45,7 @@ def store_field_states_fn(field_states): def get_g_to_field_parameter_values(num_nodes, float_precision): - return ( - 0.294 - * np.arange(0, num_nodes, dtype=get_dtype_str(float_precision))[::-1] - / (num_nodes - 1.0) - ) + return 0.294 * np.arange(0, num_nodes, dtype=get_dtype_str(float_precision))[::-1] / (num_nodes - 1.0) tentusscher_fname = "tentusscher_panfilov_2006_M_cell.ode" @@ -84,16 +80,10 @@ def __init__( self.t0 = t0 self.solver = solver self.field_states = field_states - self.field_states_fn = ( - field_states_getter_fn(num_nodes) - if field_states_getter_fn is not None - else None - ) + self.field_states_fn = field_states_getter_fn(num_nodes) if field_states_getter_fn is not None else None self.field_parameters = field_parameters self.field_parameter_values = ( - field_parameter_values_getter_fn(num_nodes, double) - if field_parameter_values_getter_fn is not None - else None + field_parameter_values_getter_fn(num_nodes, double) if field_parameter_values_getter_fn is not None else None ) self.block_size = block_size self.double = double @@ -349,10 +339,7 @@ def testUpdateStates( print("Running UPDATE HOST/FIELD STATES tests...") - names = [ - "host={0}, field={1}".format(h, f) - for h, f in it.product(update_host_states, update_field_states) - ] + names = ["host={0}, field={1}".format(h, f) for h, f in it.product(update_host_states, update_field_states)] all_field_states, runtimes, errors = list(zip(*runTests(testcases))) diff --git a/tests/codegeneration/test_codegeneration.py b/tests/codegeneration/test_codegeneration.py index b3e48a4..a9e9c62 100644 --- a/tests/codegeneration/test_codegeneration.py +++ b/tests/codegeneration/test_codegeneration.py @@ -41,11 +41,7 @@ def get_all_options_no_default(d, key): def get_indexed(comp, name): gotran.codecomponent.check_arg(comp, gotran.CodeComponent) - return [ - expr - for expr in comp.body_expressions - if isinstance(expr, gotran.IndexedExpression) and expr.basename == name - ] + return [expr for expr in comp.body_expressions if isinstance(expr, gotran.IndexedExpression) and expr.basename == name] @pytest.fixture(scope="session") @@ -166,11 +162,7 @@ def _test_codegeneration( assert rhs_norm < eps # Only evaluate jacobian if using full body_optimization and body repr is reused_array - if ( - body_optimize != "numerals_symbols" - and body_repr != "reused_array" - and param_repr == "named" - ): + if body_optimize != "numerals_symbols" and body_repr != "reused_array" and param_repr == "named": return jac_comp = gotran.jacobian_expressions(ode, params=code_params) diff --git a/tests/input/test_cellml.py b/tests/input/test_cellml.py index 5ecc563..3a8aaba 100644 --- a/tests/input/test_cellml.py +++ b/tests/input/test_cellml.py @@ -57,17 +57,11 @@ terkildsen_niederer_crampin_hunter_smith_2008="NameError: name 'I_Na' is not defined", niederer_hunter_smith_2006="NameError: name 'J_TRPN' is not defined", winslow_rice_jafri_marban_ororke_1999=( - "self.assertTrue(rel_diff<6e-3), " - "AssertionError: False is not true, " - "Rel diff: 0.3952307989374001" + "self.assertTrue(rel_diff<6e-3), " "AssertionError: False is not true, " "Rel diff: 0.3952307989374001" ), maleckar_greenstein_trayanova_giles_2009=("assert 0.5240054057725506 < 0.006"), ) -cellml_models = [ - model - for model in glob.glob(_here.joinpath("*.cellml").as_posix()) - if Path(model).stem not in skip -] +cellml_models = [model for model in glob.glob(_here.joinpath("*.cellml").as_posix()) if Path(model).stem not in skip] # Copy of default parameters diff --git a/tests/input/test_compile_module.py b/tests/input/test_compile_module.py index 9b7352d..94dc39f 100644 --- a/tests/input/test_compile_module.py +++ b/tests/input/test_compile_module.py @@ -86,11 +86,7 @@ def test_compile_monitored(ode, generation): monitor_c = c_module.monitor(states, 0, parameters) assert np.isclose(monitor_python, monitor_c).all() - assert ( - python_module.monitor_indices("i_CaL") - == c_module.monitor_indices("i_CaL") - == monitored.index("i_CaL") - ) + assert python_module.monitor_indices("i_CaL") == c_module.monitor_indices("i_CaL") == monitored.index("i_CaL") @require_cppyy diff --git a/tests/model/test_ode.py b/tests/model/test_ode.py index 2aad499..b8fbacd 100644 --- a/tests/model/test_ode.py +++ b/tests/model/test_ode.py @@ -104,9 +104,7 @@ def test_creation(): assert ode.object_used_in[tmp3] == {ode.present_ode_objects["dl_dt"]} for sym in symbols_from_expr(tmp3.expr, include_derivatives=True): - assert ( - ode.present_ode_objects[sympycode(sym)] in ode.expression_dependencies[tmp3] - ) + assert ode.present_ode_objects[sympycode(sym)] in ode.expression_dependencies[tmp3] # Add another component to test rates bada = ode("bada") @@ -344,21 +342,7 @@ def test_subode(): i_p_Ca = ode("Calcium pump current").i_p_Ca # Membrane potential derivative - mem.dV_dt = ( - -i_Ks - - i_p_K - - i_Na - - i_K1 - - i_p_Ca - - i_b_Ca - - i_NaK - - i_CaL - - i_Kr - - ode.i_Stim - - i_NaCa - - i_b_Na - - i_to - ) + mem.dV_dt = -i_Ks - i_p_K - i_Na - i_K1 - i_p_Ca - i_b_Ca - i_NaK - i_CaL - i_Kr - ode.i_Stim - i_NaCa - i_b_Na - i_to # Finalize ODE ode.finalize() diff --git a/tests/model/test_odeobjects.py b/tests/model/test_odeobjects.py index fed1ecb..8d4e303 100644 --- a/tests/model/test_odeobjects.py +++ b/tests/model/test_odeobjects.py @@ -24,16 +24,12 @@ def test_odeobjects(): # breakpoint() assert cm.value.args == ( - "expected 'str' (got '45' which " - "is 'int') as the first argument while instantiating" - " 'ODEObject'", + "expected 'str' (got '45' which " "is 'int') as the first argument while instantiating" " 'ODEObject'", ) with pytest.raises(gotran.GotranException) as cm: gotran.ODEObject("_jada") - assert cm.value.args == ( - "No ODEObject names can start " "with an underscore: '_jada'", - ) + assert cm.value.args == ("No ODEObject names can start " "with an underscore: '_jada'",) obj0 = gotran.ODEObject("jada bada") assert str(obj0) == "jada bada" @@ -52,16 +48,11 @@ def test_odevalueobjects(): with pytest.raises(gotran.GotranException) as cm: gotran.ODEValueObject("_jada", 45) - assert cm.value.args == ( - "No ODEObject names can start " "with an underscore: '_jada'", - ) + assert cm.value.args == ("No ODEObject names can start " "with an underscore: '_jada'",) obj = gotran.ODEValueObject("bada", 45) - assert ( - Symbol(obj.name, real=True, imaginary=False, commutative=True, hermitian=True) - == obj.sym - ) + assert Symbol(obj.name, real=True, imaginary=False, commutative=True, hermitian=True) == obj.sym assert 45 == obj.value @@ -73,9 +64,7 @@ def test_state(): with pytest.raises(gotran.GotranException) as cm: gotran.State("_jada", 45, t) - assert cm.value.args == ( - "No ODEObject names can start " "with an underscore: '_jada'", - ) + assert cm.value.args == ("No ODEObject names can start " "with an underscore: '_jada'",) s = gotran.State("s", 45.0, t) a = gotran.State("a", 56.0, t) @@ -106,9 +95,7 @@ def test_param(): with pytest.raises(gotran.GotranException) as cm: gotran.Parameter("_jada", 45) - assert cm.value.args == ( - "No ODEObject names can start " "with an underscore: '_jada'", - ) + assert cm.value.args == ("No ODEObject names can start " "with an underscore: '_jada'",) s = gotran.Parameter("s", 45.0)