Skip to content

Commit

Permalink
updates per Bret's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
robfalck committed Aug 23, 2023
1 parent 651f199 commit 103b9c2
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 35 deletions.
10 changes: 5 additions & 5 deletions dymos/trajectory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..phase.analytic_phase import AnalyticPhase
from ..phase.options import TrajParameterOptionsDictionary
from ..transcriptions.common import ParameterComp
from ..utils.misc import get_rate_units, _unspecified
from ..utils.misc import get_rate_units, _unspecified, _none_or_unspecified
from ..utils.introspection import get_promoted_vars, get_source_metadata, _get_common_metadata
from .._options import options as dymos_options

Expand Down Expand Up @@ -290,10 +290,10 @@ def _get_phase_parameters(self):
"""
phase_param_options = {}
for phs in self.phases._subsystems_myproc:
phase_param_options.update({phs.name: phs.parameter_options})
phase_param_options[phs.name] = phs.parameter_options

if self.comm.size > 1:
data = self.comm.gather(phase_param_options, root=0)
data = self.comm.allgather(phase_param_options)
if data:
for d in data:
phase_param_options.update(d)
Expand Down Expand Up @@ -499,7 +499,7 @@ def _configure_parameters(self):
if options['units'] is _unspecified:
options['units'] = _get_common_metadata(targets, metadata_key='units')

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
options['shape'] = _get_common_metadata(targets, metadata_key='shape')

param_comp = self._get_subsystem('param_comp')
Expand Down Expand Up @@ -535,7 +535,7 @@ def _configure_phase_options_dicts(self):

all_ranks = self.comm.allgather(options['shape'])
for item in all_ranks:
if item not in {None, _unspecified}:
if item not in _none_or_unspecified:
options['shape'] = item
break
else:
Expand Down
4 changes: 2 additions & 2 deletions dymos/transcriptions/common/parameter_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from openmdao.core.explicitcomponent import ExplicitComponent
from ...utils.misc import _unspecified
from ...utils.misc import _none_or_unspecified
from ..._options import options as dymos_options


Expand Down Expand Up @@ -133,7 +133,7 @@ def add_parameter(self, name, val=1.0, shape=None, output_name=None,

_val = np.asarray(val)

if shape in {None, _unspecified}:
if shape in _none_or_unspecified:
_shape = (1,)
size = _val.size
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
configure_time_introspection, configure_parameters_introspection, \
configure_states_discovery, configure_states_introspection, _get_targets_metadata, \
_get_common_metadata, get_promoted_vars
from ...utils.misc import get_rate_units, _unspecified
from ...utils.misc import get_rate_units, _unspecified, _none_or_unspecified


class ODEEvaluationGroup(om.Group):
Expand Down Expand Up @@ -215,7 +215,7 @@ def _configure_params(self):
else:
units = options['units']

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
shape = _get_common_metadata(targets, 'shape')
else:
shape = options['shape']
Expand Down
4 changes: 2 additions & 2 deletions dymos/transcriptions/transcription_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .common import ControlGroup, PolynomialControlGroup, ParameterComp
from ..utils.constants import INF_BOUND
from ..utils.indexing import get_constraint_flat_idxs
from ..utils.misc import _unspecified
from ..utils.misc import _none_or_unspecified
from ..utils.introspection import configure_states_introspection, get_promoted_vars, \
configure_states_discovery

Expand Down Expand Up @@ -105,7 +105,7 @@ def configure_time(self, phase):
time_options = phase.time_options

# Determine the time unit.
if time_options['units'] in {None, _unspecified}:
if time_options['units'] in _none_or_unspecified:
if time_options['targets']:
ode = phase._get_subsystem(self._rhs_source)

Expand Down
45 changes: 21 additions & 24 deletions dymos/utils/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import openmdao.api as om
import numpy as np
from openmdao.utils.array_utils import shape_to_len
from openmdao.utils.general_utils import ensure_compatible
from dymos.utils.misc import _unspecified
from dymos.utils.misc import _unspecified, _none_or_unspecified
from .._options import options as dymos_options
from ..phase.options import StateOptionsDictionary, TimeseriesOutputOptionsDictionary
from .misc import get_rate_units
Expand Down Expand Up @@ -316,7 +315,7 @@ def configure_controls_introspection(control_options, ode, time_units='s'):
if options['units'] is _unspecified:
options['units'] = _get_common_metadata(targets, metadata_key='units')

if options['shape'] in {_unspecified, None}:
if options['shape'] in _none_or_unspecified:
shape = _get_common_metadata(targets, metadata_key='shape')
if len(shape) == 1:
options['shape'] = (1,)
Expand All @@ -337,7 +336,7 @@ def configure_controls_introspection(control_options, ode, time_units='s'):
rate_target_units = _get_common_metadata(rate_targets, metadata_key='units')
options['units'] = time_units if rate_target_units is None else f'{rate_target_units}*{time_units}'

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
shape = _get_common_metadata(rate_targets, metadata_key='shape')
if len(shape) == 1:
options['shape'] = (1,)
Expand All @@ -359,7 +358,7 @@ def configure_controls_introspection(control_options, ode, time_units='s'):
options['units'] = f'{time_units**2}' if rate2_target_units is None \
else f'{rate2_target_units}*{time_units}**2'

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
shape = _get_common_metadata(rate2_targets, metadata_key='shape')
if len(shape) == 1:
options['shape'] = (1,)
Expand Down Expand Up @@ -412,29 +411,27 @@ def configure_parameters_introspection(parameter_options, ode):
options['units'] = _get_common_metadata(targets, metadata_key='units')

# Check that all targets have the same shape.
static_shapes = {}
dynamic_shapes = {}
tgt_shapes = {}
# First find the shapes of the static targets
for tgt, meta in targets.items():
if tgt in options['static_targets']:
static_shapes[tgt] = meta['shape']
tgt_shapes[tgt] = meta['shape']
else:
if len(meta['shape']) == 1:
dynamic_shapes[tgt] = (1,)
tgt_shapes[tgt] = (1,)
else:
dynamic_shapes[tgt] = meta['shape'][1:]
all_shapes = {**dynamic_shapes, **static_shapes}
tgt_shapes[tgt] = meta['shape'][1:]
# Check that they're unique
if len(set(all_shapes.values())) > 1:
if len(set(tgt_shapes.values())) > 1:
raise RuntimeError(f'Invalid targets for parameter `{name}`.\n'
f'Targets have multiple shapes.\n'
f'{all_shapes}')
elif len(set(all_shapes.values())) == 1:
introspected_shape = next(iter(set(all_shapes.values())))
f'{tgt_shapes}')
elif len(set(tgt_shapes.values())) == 1:
introspected_shape = next(iter(set(tgt_shapes.values())))
else:
introspected_shape = None

if options['shape'] in {_unspecified, None}:
if options['shape'] in _none_or_unspecified:
if isinstance(options['val'], Number):
options['shape'] = introspected_shape
else:
Expand All @@ -444,7 +441,7 @@ def configure_parameters_introspection(parameter_options, ode):
raise RuntimeError(f'Shape provided to parameter `{name}` differs from its targets.\n'
f'Given shape: {options["shape"]}\n'
f'Target shapes:\n'
f'{all_shapes}')
f'{tgt_shapes}')

options['val'], options['shape'] = ensure_compatible(name, options['val'], options['shape'])

Expand Down Expand Up @@ -534,7 +531,7 @@ def configure_states_introspection(state_options, time_options, control_options,
if options['units'] is _unspecified:
options['units'] = _get_common_metadata(targets, metadata_key='units')

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
shape = _get_common_metadata(targets, metadata_key='shape')
if len(shape) == 1:
options['shape'] = (1,)
Expand Down Expand Up @@ -593,7 +590,7 @@ def configure_states_introspection(state_options, time_options, control_options,
rate_src_shape = (1,)
rate_src_units = None

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
options['shape'] = rate_src_shape

if options['units'] is _unspecified:
Expand Down Expand Up @@ -643,7 +640,7 @@ def configure_analytic_states_introspection(state_options, ode):
raise RuntimeError(f'ODE output {source} is tagged with `dymos.static_output` and cannot be used as a '
f'state variable in an AnalyticPhase.')

if options['shape'] in {None, _unspecified}:
if options['shape'] in _none_or_unspecified:
options['shape'] = src_shape

if options['units'] is _unspecified:
Expand Down Expand Up @@ -879,7 +876,7 @@ def configure_timeseries_expr_introspection(phase):
expr_reduced = expr

units = output_options['units'] if output_options['units'] is not _unspecified else None
shape = output_options['shape'] if output_options['shape'] not in {_unspecified, None} else (1,)
shape = output_options['shape'] if output_options['shape'] not in _none_or_unspecified else (1,)

abs_names = [x.strip() for x in re.findall(var_names_regex, expr)
if not x.endswith('(') and not x.endswith(':')]
Expand Down Expand Up @@ -1193,7 +1190,7 @@ def _get_common_metadata(targets, metadata_key):
ValueError
ValueError is raised if the targets do not all have the same metadata value.
"""
meta_set = {meta[metadata_key] for tgt, meta in targets.items()}
meta_set = {meta[metadata_key] for meta in targets.values()}

if len(meta_set) == 1:
return next(iter(meta_set))
Expand Down Expand Up @@ -1249,12 +1246,12 @@ def get_source_metadata(ode, src, user_units=_unspecified, user_shape=_unspecifi
if src not in ode_outputs:
raise ValueError(f"Unable to find the source '{src}' in the ODE.")

if user_units in {None, _unspecified}:
if user_units in _none_or_unspecified:
meta['units'] = ode_outputs[src]['units']
else:
meta['units'] = user_units

if user_shape in {None, _unspecified}:
if user_shape in _none_or_unspecified:
ode_shape = ode_outputs[src]['shape']
meta['shape'] = (1,) if len(ode_shape) == 1 else ode_shape[1:]
else:
Expand Down
1 change: 1 addition & 0 deletions dymos/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# unique object to check if default is given (when None is an allowed value)
_unspecified = _ReprClass("unspecified")
_none_or_unspecified = {None, _unspecified}


def get_rate_units(units, time_units, deriv=1):
Expand Down

0 comments on commit 103b9c2

Please sign in to comment.