diff --git a/dymos/examples/oscillator/oscillator_ode.py b/dymos/examples/oscillator/oscillator_ode.py index dac0fcd7d..a46aac647 100644 --- a/dymos/examples/oscillator/oscillator_ode.py +++ b/dymos/examples/oscillator/oscillator_ode.py @@ -36,3 +36,37 @@ def compute(self, inputs, outputs): f_damper = -c * v outputs['v_dot'] = (f_spring + f_damper) / m + + +class OscillatorVectorODE(om.ExplicitComponent): + """ + A Dymos ODE for a damped harmonic oscillator with vectorized states. + """ + + def initialize(self): + self.options.declare('num_nodes', types=int) + self.options.declare('static_params', types=bool) + + def setup(self): + nn = self.options['num_nodes'] + static = self.options['static_params'] + + # Inputs + self.add_input('x', val=np.ones((nn, 2))) + if static: + self.add_input('A', val=np.ones((2, 2))) + else: + self.add_input('A', val=np.ones((nn, 2, 2))) + + # Output + self.add_output('x_dot', val=np.zeros((nn, 2))) + + self.declare_partials(of='*', wrt='*', method='fd') + + def compute(self, inputs, outputs): + A = inputs['A'] + x = inputs['x'] + + static = self.options['static_params'] + ein_sequence = 'jk, ik->ij' if static else 'ijk, ik->ij' + outputs['x_dot'] = np.einsum(ein_sequence, A, x) diff --git a/dymos/examples/oscillator/test/test_oscillator_vector_states.py b/dymos/examples/oscillator/test/test_oscillator_vector_states.py new file mode 100644 index 000000000..924d19afa --- /dev/null +++ b/dymos/examples/oscillator/test/test_oscillator_vector_states.py @@ -0,0 +1,99 @@ +import unittest +from openmdao.utils.testing_utils import use_tempdirs +from openmdao.utils.assert_utils import assert_near_equal + +import openmdao.api as om +import dymos as dm +import numpy as np + + +# @use_tempdirs +class TestDocOscillator(unittest.TestCase): + + def test_matrix_param(self): + + from dymos.examples.oscillator.oscillator_ode import OscillatorVectorODE + + # Instantiate an OpenMDAO Problem instance. + prob = om.Problem() + prob.driver = om.ScipyOptimizeDriver() + prob.driver.options["optimizer"] = 'SLSQP' + + static_params = False + + t = dm.Radau(num_segments=2, order=3) + phase = dm.Phase(ode_class=OscillatorVectorODE, transcription=t, + ode_init_kwargs={'static_params': static_params}) + + phase.set_time_options(fix_initial=True, duration_bounds=(1, 2), duration_ref=1) + phase.add_state("x", fix_initial=True, rate_source="x_dot") + + A_mat = np.array( + [ + [0, 1], + [-1, 0] + ] + ) + + # argument "dynamic" doesn't seem to help + phase.add_parameter("A", val=A_mat, targets=["A"], static_target=static_params) + phase.add_objective("time", loc="final", scaler=1) + + traj = dm.Trajectory() + traj.add_phase("phase0", phase) + + prob.model.add_subsystem("traj", traj) + + prob.driver.declare_coloring() + prob.setup(force_alloc_complex=True) + phase.set_state_val('x', vals=[[1, 0], [1, 0]]) + + dm.run_problem(prob, run_driver=True, simulate=True, make_plots=True) + t_f = prob.get_val('traj.phase0.timeseries.time')[-1] + final_state = prob.get_val('traj.phase0.timeseries.x')[-1, :] + assert_near_equal(final_state, np.array([np.cos(t_f), -np.sin(t_f)]).ravel(), + tolerance=1e-5) + + def test_matrix_static_param(self): + + from dymos.examples.oscillator.oscillator_ode import OscillatorVectorODE + + # Instantiate an OpenMDAO Problem instance. + prob = om.Problem() + prob.driver = om.ScipyOptimizeDriver() + prob.driver.options["optimizer"] = 'SLSQP' + + static_params = True + + t = dm.Radau(num_segments=2, order=3) + phase = dm.Phase(ode_class=OscillatorVectorODE, transcription=t, + ode_init_kwargs={'static_params': static_params}) + + phase.set_time_options(fix_initial=True, duration_bounds=(1, 2), duration_ref=1) + phase.add_state("x", fix_initial=True, rate_source="x_dot") + + A_mat = np.array( + [ + [0, 1], + [-1, 0] + ] + ) + + # argument "dynamic" doesn't seem to help + phase.add_parameter("A", val=A_mat, targets=["A"], static_target=static_params) + phase.add_objective("time", loc="final", scaler=1) + + traj = dm.Trajectory() + traj.add_phase("phase0", phase) + + prob.model.add_subsystem("traj", traj) + + prob.driver.declare_coloring() + prob.setup(force_alloc_complex=True) + phase.set_state_val('x', vals=[[1, 0], [1, 0]]) + + dm.run_problem(prob, run_driver=True, simulate=True, make_plots=True) + t_f = prob.get_val('traj.phase0.timeseries.time')[-1] + final_state = prob.get_val('traj.phase0.timeseries.x')[-1, :] + assert_near_equal(final_state, np.array([np.cos(t_f), -np.sin(t_f)]).ravel(), + tolerance=1e-5) diff --git a/dymos/transcriptions/explicit_shooting/explicit_shooting.py b/dymos/transcriptions/explicit_shooting/explicit_shooting.py index a83572864..a194f9e57 100644 --- a/dymos/transcriptions/explicit_shooting/explicit_shooting.py +++ b/dymos/transcriptions/explicit_shooting/explicit_shooting.py @@ -283,10 +283,6 @@ def configure_states(self, phase): ref0=options['ref0'], ref=options['ref']) - def _get_ode(self, phase): - ode = phase._get_subsystem('ode') - return ode - def setup_ode(self, phase): """ Setup the ode for this transcription. diff --git a/dymos/transcriptions/explicit_shooting/ode_evaluation_group.py b/dymos/transcriptions/explicit_shooting/ode_evaluation_group.py index a7caef380..2aa64ad0e 100644 --- a/dymos/transcriptions/explicit_shooting/ode_evaluation_group.py +++ b/dymos/transcriptions/explicit_shooting/ode_evaluation_group.py @@ -218,7 +218,6 @@ def _configure_states(self): self.add_constraint(f'state_rate_collector.state_rates:{name}_rate') def _configure_params(self): - vec_size = self._vec_size ode_inputs = get_promoted_vars(self.ode, iotypes='input', metadata_keys=['shape', 'units', 'val', 'tags']) for name, options in self._parameter_options.items(): @@ -242,13 +241,8 @@ def _configure_params(self): # Promote targets from the ODE for tgt in targets: if tgt in options['static_targets']: - src_idxs = None shape = None - else: - src_rows = np.zeros(vec_size, dtype=int) - src_idxs = om.slicer[src_rows, ...] - - self.promotes('ode', inputs=[(tgt, var_name)], src_indices=src_idxs, + self.promotes('ode', inputs=[(tgt, var_name)], src_shape=shape) if targets: self.set_input_defaults(name=var_name,