From 0b9557f64fb3368b8cc54d14efcd103a45080208 Mon Sep 17 00:00:00 2001 From: G Webb Date: Fri, 2 Aug 2024 15:43:56 +0100 Subject: [PATCH 1/7] Make the measurement label for plotters consistent --- stonesoup/plotter.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py index 529686530..21c089819 100644 --- a/stonesoup/plotter.py +++ b/stonesoup/plotter.py @@ -285,7 +285,11 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurements_handle = Line2D([], [], linestyle='', **measurement_kwargs) # Generate legend items for measurements - self.legend_dict[measurements_label] = measurements_handle + if plot_clutter: + name = measurements_label + "
(Detections)" + else: + name = measurements_label + self.legend_dict[name] = measurements_handle if plot_clutter: clutter_kwargs = kwargs.copy() @@ -293,10 +297,10 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, clutter_array = np.array(list(plot_clutter.values())) artists.append(self.ax.scatter(*clutter_array.T, **clutter_kwargs)) clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs) - clutter_label = "Clutter" # Generate legend items for clutter - self.legend_dict[clutter_label] = clutter_handle + name = measurements_label + "
(Clutter)" + self.legend_dict[name] = clutter_handle # Generate legend artists.append(self.ax.legend(handles=self.legend_dict.values(), @@ -1146,7 +1150,10 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, convert_measurements) if plot_detections: - name = measurements_label + "
(Detections)" + if plot_clutter: + name = measurements_label + "
(Detections)" + else: + name = measurements_label measurement_kwargs = dict( mode='markers', marker=dict(color='#636EFA'), name=name, legendgroup=name, legendrank=200) @@ -1708,7 +1715,10 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, range_mapping = None if plot_detections: - name = measurements_label + "
(Detections)" + if plot_clutter: + name = measurements_label + "
(Detections)" + else: + name = measurements_label measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200) merge(measurement_kwargs, kwargs) plotting_data = [State(state_vector=plotting_state_vector, @@ -1957,7 +1967,7 @@ def plot_state_mutable_sequence(self, state_mutable_sequences, mapping: List[int )) def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="", convert_measurements=True, **kwargs): + measurements_label="Measurements", convert_measurements=True, **kwargs): """Plots measurements Plots detections and clutter, generating a legend automatically. Detections are plotted as @@ -1977,7 +1987,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, User-defined measurement model to be used in finding measurement state inverses if they cannot be found from the measurements themselves. measurements_label: str - Label for measurements. Default will be "Detections" or "Clutter" + Label for measurements. Default is "Detections". convert_measurements: bool Should the measurements be converted from measurement space to state space before being plotted. Default is True @@ -2002,17 +2012,18 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurement_model, convert_measurements) - if measurements_label != "": - measurements_label = measurements_label + " " - if plot_detections: + if plot_clutter: + name = measurements_label + "
(Detections)" + else: + name = measurements_label detection_kwargs = dict(linestyle='', marker='o', color='b') detection_kwargs.update(kwargs) self.plotting_data.append(_AnimationPlotterDataClass( plotting_data=[State(state_vector=plotting_state_vector, timestamp=detection.timestamp) for detection, plotting_state_vector in plot_detections.items()], - plotting_label=measurements_label + "Detections", + plotting_label=name, plotting_keyword_arguments=detection_kwargs )) @@ -2023,7 +2034,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, plotting_data=[State(state_vector=plotting_state_vector, timestamp=detection.timestamp) for detection, plotting_state_vector in plot_clutter.items()], - plotting_label=measurements_label + "Clutter", + plotting_label=measurements_label + "
(Clutter)", plotting_keyword_arguments=clutter_kwargs )) @@ -2636,7 +2647,10 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, trace_base = len(self.fig.data) # initialise detections - name = measurements_label + "
(Detections)" + if plot_clutter: + name = measurements_label + "
(Detections)" + else: + name = measurements_label measurement_kwargs = dict(x=[], y=[], mode='markers', name=name, legendgroup=name, From d8c895dd1df448db2f8f37fcad4b96f11850955e Mon Sep 17 00:00:00 2001 From: G Webb Date: Fri, 2 Aug 2024 15:56:11 +0100 Subject: [PATCH 2/7] Fix tests in stonesoup/tests/test_plotter.py due to measurement label changing --- stonesoup/tests/test_plotter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stonesoup/tests/test_plotter.py b/stonesoup/tests/test_plotter.py index 6e36fc327..eb0ec2fe3 100644 --- a/stonesoup/tests/test_plotter.py +++ b/stonesoup/tests/test_plotter.py @@ -321,7 +321,7 @@ def test_plotterly_wrong_dimension(dim, mapping): @pytest.mark.parametrize("labels", [ None, ["Tracks"], ["Ground Truth", "Tracks"], - ["Ground Truth", "Measurements
(Detections)", "Tracks"]]) + ["Ground Truth", "Measurements", "Tracks"]]) def test_hide_plot(labels): plotter = Plotterly() plotter.plot_ground_truths(truth, [0, 1]) @@ -348,7 +348,7 @@ def test_hide_plot(labels): @pytest.mark.parametrize("labels", [ None, ["Tracks"], ["Ground Truth", "Tracks"], - ["Ground Truth", "Measurements
(Detections)", "Tracks"]]) + ["Ground Truth", "Measurements", "Tracks"]]) def test_show_plot(labels): plotter = Plotterly() plotter.plot_ground_truths(truth, [0, 1]) From b0695c6a840b4c3df106418dfbae2cae6f7041c2 Mon Sep 17 00:00:00 2001 From: G Webb Date: Thu, 15 Aug 2024 14:13:19 +0100 Subject: [PATCH 3/7] Replaced 'truths_label', 'measurements_label', 'track_label' and 'sensor_label' with 'label' in plotter.py --- stonesoup/plotter.py | 288 ++++++++++++++++++++++++++++++------------- 1 file changed, 201 insertions(+), 87 deletions(-) diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py index 21c089819..ef3d8fab4 100644 --- a/stonesoup/plotter.py +++ b/stonesoup/plotter.py @@ -53,21 +53,21 @@ class Dimension(IntEnum): class _Plotter(ABC): @abstractmethod - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs): raise NotImplementedError @abstractmethod def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", **kwargs): + label="Measurements", **kwargs): raise NotImplementedError @abstractmethod - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", + def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks", **kwargs): raise NotImplementedError @abstractmethod - def plot_sensors(self, sensors, mapping, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, mapping, label="Sensors", **kwargs): raise NotImplementedError def _conv_measurements(self, measurements, mapping, measurement_model=None, @@ -166,7 +166,7 @@ def __init__(self, dimension=Dimension.TWO, **kwargs): # This is new compared to plotter.py self.legend_dict = {} # create an empty dictionary to hold legend entries - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs): """Plots ground truth(s) Plots each ground truth path passed in to :attr:`truths` and generates a legend @@ -183,7 +183,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa for iteration. mapping: list List of items specifying the mapping of the position components of the state space. - truths_label: str + label: str Label for truth data. Default is "Ground Truth" \\*\\*kwargs: dict Additional arguments to be passed to plot function. Default is ``linestyle="--"``. @@ -192,7 +192,13 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa ------- : list of :class:`matplotlib.artist.Artist` List of artists that have been added to the axis. + + + .. deprecated:: 1.5 + ``label`` has replaced ``truths_label``. In the current implementation ``truths_label`` + overrides ``label``. However, use of ``truths_label`` may be removed in the future. """ + label = kwargs.pop('truths_label', None) or label truths_kwargs = dict(linestyle="--") truths_kwargs.update(kwargs) if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): @@ -219,14 +225,14 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa else: colour = "black" truths_handle = Line2D([], [], linestyle=truths_kwargs['linestyle'], color=colour) - self.legend_dict[truths_label] = truths_handle + self.legend_dict[label] = truths_handle # Generate legend artists.append(self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())) return artists def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): + label="Measurements", convert_measurements=True, **kwargs): """Plots measurements Plots detections and clutter, generating a legend automatically. Detections are plotted as @@ -245,7 +251,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurement_model : :class:`~.Model`, optional User-defined measurement model to be used in finding measurement state inverses if they cannot be found from the measurements themselves. - measurements_label : str + label : str Label for the measurements. Default is "Measurements". convert_measurements : bool Should the measurements be converted from measurement space to state space before @@ -258,8 +264,14 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, ------- : list of :class:`matplotlib.artist.Artist` List of artists that have been added to the axis. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``measurements_label``. In the current implementation + ``measurements_label`` overrides ``label``. However, use of ``measurements_label`` + may be removed in the future. + """ + label = kwargs.pop('measurements_label', None) or label measurement_kwargs = dict(marker='o', color='b') measurement_kwargs.update(kwargs) @@ -286,9 +298,9 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, # Generate legend items for measurements if plot_clutter: - name = measurements_label + "
(Detections)" + name = label + "
(Detections)" else: - name = measurements_label + name = label self.legend_dict[name] = measurements_handle if plot_clutter: @@ -299,7 +311,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs) # Generate legend items for clutter - name = measurements_label + "
(Clutter)" + name = label + "
(Clutter)" self.legend_dict[name] = clutter_handle # Generate legend @@ -307,7 +319,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, labels=self.legend_dict.keys())) return artists - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", + def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks", err_freq=1, same_color=False, **kwargs): """Plots track(s) @@ -333,7 +345,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ If True, function plots uncertainty ellipses or bars. particle : bool If True, function plots particles. - track_label: str + label: str Label to apply to all tracks for legend. err_freq: int Frequency of error bar plotting on tracks. Default value is 1, meaning @@ -348,8 +360,14 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ ------- : list of :class:`matplotlib.artist.Artist` List of artists that have been added to the axis. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``track_label``. In the current implementation + ``track_label`` overrides ``label``. However, use of ``track_label`` + may be removed in the future. + """ + label = kwargs.pop('track_label', None) or label tracks_kwargs = dict(linestyle='-', marker="s", color=None) tracks_kwargs.update(kwargs) if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): @@ -396,7 +414,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ # Generate legend items for track track_handle = Line2D([], [], linestyle=tracks_kwargs['linestyle'], marker=tracks_kwargs['marker'], color=tracks_kwargs['color']) - self.legend_dict[track_label] = track_handle + self.legend_dict[label] = track_handle if uncertainty: if self.dimension is Dimension.TWO: # Plot uncertainty ellipses @@ -487,7 +505,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ return artists - def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, mapping=None, label="Sensors", **kwargs): """Plots sensor(s) Plots sensors. Users can change the color and marker of sensors using keyword @@ -500,7 +518,7 @@ def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): mapping: list List of items specifying the mapping of the position components of the sensor's position. Default is either [0, 1] or [0, 1, 2] depending on `self.dimension` - sensor_label: str + label: str Label to apply to all sensors for legend. \\*\\*kwargs: dict Additional arguments to be passed to plot function for sensors. Defaults are @@ -510,8 +528,14 @@ def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): ------- : list of :class:`matplotlib.artist.Artist` List of artists that have been added to the axis. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``sensor_label``. In the current implementation + ``sensor_label`` overrides ``label``. However, use of ``sensor_label`` + may be removed in the future. + """ + label = kwargs.pop('sensor_label', None) or label sensor_kwargs = dict(marker='x', color='black') sensor_kwargs.update(kwargs) @@ -534,7 +558,7 @@ def plot_sensors(self, sensors, mapping=None, sensor_label="Sensors", **kwargs): **sensor_kwargs)) else: raise NotImplementedError('Unsupported dimension type for sensor plotting') - self.legend_dict[sensor_label] = Line2D([], [], linestyle='', **sensor_kwargs) + self.legend_dict[label] = Line2D([], [], linestyle='', **sensor_kwargs) artists.append(self.ax.legend(handles=self.legend_dict.values(), labels=self.legend_dict.keys())) return artists @@ -1035,7 +1059,7 @@ def _check_mapping(self, mapping): elif len(mapping) != self.dimension: raise TypeError("Plotter dimension is not same as the mapping dimension.") - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs): """Plots ground truth(s) Plots each ground truth path passed in to :attr:`truths` and generates a legend @@ -1052,20 +1076,27 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa set to allow for iteration. mapping: list List of items specifying the mapping of the position components of the state space. - truths_label: str + label: str Label for truth data. Default is "Ground Truth" \\*\\*kwargs: dict Additional arguments to be passed to scatter function. Default is ``line=dict(dash="dash")``. + + + .. deprecated:: 1.5 + ``label`` has replaced ``truths_label``. In the current implementation + ``truths_label`` overrides ``label``. However, use of ``truths_label`` + may be removed in the future. """ + label = kwargs.pop('truths_label', None) or label if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): truths = {truths} self._check_mapping(mapping) # ensure mapping is compatible with plotter dimension truths_kwargs = dict( - mode="lines", line=dict(dash="dash"), legendgroup=truths_label, legendrank=100, - name=truths_label) + mode="lines", line=dict(dash="dash"), legendgroup=label, legendrank=100, + name=label) if self.dimension == 3: # make ground truth line thicker so easier to see in 3d plot truths_kwargs.update(dict(line=dict(width=8, dash="longdashdot"))) @@ -1105,7 +1136,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa **scatter_kwargs) def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): + label="Measurements", convert_measurements=True, **kwargs): """Plots measurements Plots detections and clutter, generating a legend automatically. Detections are plotted as @@ -1124,7 +1155,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurement_model : :class:`~.Model`, optional User-defined measurement model to be used in finding measurement state inverses if they cannot be found from the measurements themselves. - measurements_label : str + label : str Label for the measurements. Default is "Measurements". convert_measurements: bool Should the measurements be converted from measurement space to state space before @@ -1132,8 +1163,14 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, \\*\\*kwargs: dict Additional arguments to be passed to scatter function for detections. Defaults are ``marker=dict(color="#636EFA")``. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``measurements_label``. In the current implementation + ``measurements_label`` overrides ``label``. However, use of ``measurements_label`` + may be removed in the future. + """ + label = kwargs.pop('measurements_label', None) or label if not isinstance(measurements, Collection): measurements = {measurements} @@ -1151,9 +1188,9 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_detections: if plot_clutter: - name = measurements_label + "
(Detections)" + name = label + "
(Detections)" else: - name = measurements_label + name = label measurement_kwargs = dict( mode='markers', marker=dict(color='#636EFA'), name=name, legendgroup=name, legendrank=200) @@ -1193,7 +1230,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, ) if plot_clutter: - name = measurements_label + "
(Clutter)" + name = label + "
(Clutter)" clutter_kwargs = dict( mode='markers', marker=dict(symbol="star-triangle-up", color='#FECB52'), name=name, legendgroup=name, legendrank=210) @@ -1255,7 +1292,7 @@ def get_next_color(self): color_index = figure_index % max_index return colorway[color_index] - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", + def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks", ellipse_points=30, err_freq=1, same_color=False, **kwargs): """Plots track(s) @@ -1277,7 +1314,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ If True, function plots uncertainty ellipses. particle : bool If True, function plots particles. - track_label: str + label: str Label to apply to all tracks for legend. ellipse_points: int Number of points for polygon approximating ellipse shape @@ -1290,7 +1327,14 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ Additional arguments to be passed to scatter function. Defaults are ``marker=dict(symbol='square')`` for :class:`~.Update` and ``marker=dict(symbol='circle')`` for other states. + + + .. deprecated:: 1.5 + ``label`` has replaced ``track_label``. In the current implementation + ``track_label`` overrides ``label``. However, use of ``track_label`` + may be removed in the future. """ + label = kwargs.pop('track_label', None) or label if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): tracks = {tracks} # Make a set of length 1 @@ -1298,7 +1342,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ # Plot tracks track_colors = {} - track_kwargs = dict(mode='markers+lines', legendgroup=track_label, legendrank=300) + track_kwargs = dict(mode='markers+lines', legendgroup=label, legendrank=300) if self.dimension == 3: # change visuals to work well in 3d track_kwargs.update(dict(line=dict(width=7)), marker=dict(size=4)) @@ -1319,7 +1363,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ scatter_kwargs = track_kwargs.copy() scatter_kwargs['name'] = track.id if add_legend: - scatter_kwargs['name'] = track_label + scatter_kwargs['name'] = label scatter_kwargs['showlegend'] = True add_legend = False else: @@ -1473,7 +1517,7 @@ def func3(x): points = rotational_matrix @ points.T return points + state.mean[mapping[:2], :] - def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, mapping=[0, 1], label="Sensors", **kwargs): """Plots sensor(s) Plots sensors. Users can change the color and marker of sensors using keyword @@ -1486,13 +1530,19 @@ def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs mapping: list List of items specifying the mapping of the position components of the sensor's position. - sensor_label: str + label: str Label to apply to all sensors for legend. \\*\\*kwargs: dict Additional arguments to be passed to scatter function for sensors. Defaults are ``marker=dict(symbol='x', color='black')``. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``sensor_label``. In the current implementation + ``sensor_label`` overrides ``label``. However, use of ``sensor_label`` + may be removed in the future. + """ + label = kwargs.pop('sensor_label', None) or label if not isinstance(sensors, Collection): sensors = {sensors} @@ -1502,10 +1552,10 @@ def plot_sensors(self, sensors, mapping=[0, 1], sensor_label="Sensors", **kwargs raise NotImplementedError sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'), - legendgroup=sensor_label, legendrank=50) + legendgroup=label, legendrank=50) merge(sensor_kwargs, kwargs) - sensor_kwargs['name'] = sensor_label + sensor_kwargs['name'] = label if sensor_kwargs['legendgroup'] not in {trace.legendgroup for trace in self.fig.data}: sensor_kwargs['showlegend'] = True @@ -1633,7 +1683,7 @@ def plot_state_sequence(self, state_sequences, angle_mapping: int, range_mapping theta=bearings, **scatter_kwargs) self.fig.add_trace(polar_plot) - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwargs): + def plot_ground_truths(self, truths, mapping, label="Ground Truth", **kwargs): """Plots ground truth(s) Plots each ground truth path passed in to :attr:`truths` and generates a legend @@ -1650,12 +1700,19 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa set to allow for iteration. mapping: list List of items specifying the mapping of the position components of the state space. - truths_label: str + label: str Label for truth data. Default is "Ground Truth". \\*\\*kwargs: dict Additional arguments to be passed to scatter function. Default is ``line=dict(dash="dash")``. + + + .. deprecated:: 1.5 + ``label`` has replaced ``truths_label``. In the current implementation + ``truths_label`` overrides ``label``. However, use of ``truths_label`` + may be removed in the future. """ + label = kwargs.pop('truths_label', None) or label truths_kwargs = dict(mode="lines", line=dict(dash="dash"), legendrank=100) merge(truths_kwargs, kwargs) angle_mapping = mapping[0] @@ -1664,10 +1721,10 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", **kwa else: range_mapping = None self.plot_state_sequence(state_sequences=truths, angle_mapping=angle_mapping, - range_mapping=range_mapping, label=truths_label, **truths_kwargs) + range_mapping=range_mapping, label=label, **truths_kwargs) def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): + label="Measurements", convert_measurements=True, **kwargs): """Plots measurements Plots detections and clutter, generating a legend automatically. Detections are plotted as @@ -1686,15 +1743,21 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurement_model : :class:`~.Model`, optional User-defined measurement model to be used in finding measurement state inverses if they cannot be found from the measurements themselves. - measurements_label : str + label : str Label for the measurements. Default is "Measurements". convert_measurements: bool Should the measurements be converted before being plotted. Default is True. \\*\\*kwargs: dict Additional arguments to be passed to scatter function for detections. Defaults are ``marker=dict(color="#636EFA")``. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``measurements_label``. In the current implementation + ``measurements_label`` overrides ``label``. However, use of ``measurements_label`` + may be removed in the future. + """ + label = kwargs.pop('measurements_label', None) or label if not isinstance(measurements, Collection): measurements = {measurements} @@ -1716,9 +1779,9 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_detections: if plot_clutter: - name = measurements_label + "
(Detections)" + name = label + "
(Detections)" else: - name = measurements_label + name = label measurement_kwargs = dict(mode='markers', marker=dict(color='#636EFA'), legendrank=200) merge(measurement_kwargs, kwargs) plotting_data = [State(state_vector=plotting_state_vector, @@ -1730,7 +1793,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, **measurement_kwargs) if plot_clutter: - name = measurements_label + "
(Clutter)" + name = label + "
(Clutter)" clutter_kwargs = dict(mode='markers', legendrank=210, marker=dict(symbol="star-triangle-up", color='#FECB52')) merge(clutter_kwargs, kwargs) @@ -1742,7 +1805,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, range_mapping=range_mapping, label=name, **clutter_kwargs) - def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_label="Tracks", + def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, label="Tracks", **kwargs): """Plots track(s) @@ -1764,12 +1827,19 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ If True, function plots uncertainty ellipses. particle : bool If True, function plots particles. - track_label: str + label: str Label to apply to all tracks for legend. \\*\\*kwargs: dict Additional arguments to be passed to scatter function. Defaults are ``mode='markers+lines'``. + + + .. deprecated:: 1.5 + ``label`` has replaced ``track_label``. In the current implementation + ``track_label`` overrides ``label``. However, use of ``track_label`` + may be removed in the future. """ + label = kwargs.pop('track_label', None) or label if uncertainty or particle: raise NotImplementedError @@ -1781,9 +1851,9 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, particle=False, track_ else: range_mapping = None self.plot_state_sequence(state_sequences=tracks, angle_mapping=angle_mapping, - range_mapping=range_mapping, label=track_label, **track_kwargs) + range_mapping=range_mapping, label=label, **track_kwargs) - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, label="Sensors", **kwargs): raise NotImplementedError @@ -1869,7 +1939,7 @@ def save(self, filename='example.mp4', **kwargs): self.animation_output.save(filename, **kwargs) - def plot_ground_truths(self, truths, mapping: List[int], truths_label: str = "Ground Truth", + def plot_ground_truths(self, truths, mapping: List[int], label: str = "Ground Truth", **kwargs): """Plots ground truth(s) @@ -1887,18 +1957,24 @@ def plot_ground_truths(self, truths, mapping: List[int], truths_label: str = "Gr for iteration. mapping: list List of items specifying the mapping of the position components of the state space. - truths_label: str + label: str Label for truth data. Default is "Ground Truth" \\*\\*kwargs: dict Additional arguments to be passed to plot function. Default is ``linestyle="--"``. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``truths_label``. In the current implementation + ``truths_label`` overrides ``label``. However, use of ``truths_label`` + may be removed in the future. + """ + label = kwargs.pop('truths_label', None) or label truths_kwargs = dict(linestyle="--") truths_kwargs.update(kwargs) - self.plot_state_mutable_sequence(truths, mapping, truths_label, **truths_kwargs) + self.plot_state_mutable_sequence(truths, mapping, label, **truths_kwargs) def plot_tracks(self, tracks, mapping: List[int], uncertainty=False, particle=False, - track_label="Tracks", **kwargs): + label="Tracks", **kwargs): """Plots track(s) Plots each track generated, generating a legend automatically. Tracks are plotted as solid @@ -1917,18 +1993,25 @@ def plot_tracks(self, tracks, mapping: List[int], uncertainty=False, particle=Fa Currently not implemented. If True, an error is raised particle : bool Currently not implemented. If True, an error is raised - track_label: str + label: str Label to apply to all tracks for legend. \\*\\*kwargs: dict Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. + + + .. deprecated:: 1.5 + ``label`` has replaced ``track_label``. In the current implementation + ``track_label`` overrides ``label``. However, use of ``track_label`` + may be removed in the future. """ + label = kwargs.pop('track_label', None) or label if uncertainty or particle: raise NotImplementedError tracks_kwargs = dict(linestyle='-', marker="s", color=None) tracks_kwargs.update(kwargs) - self.plot_state_mutable_sequence(tracks, mapping, track_label, **tracks_kwargs) + self.plot_state_mutable_sequence(tracks, mapping, label, **tracks_kwargs) def plot_state_mutable_sequence(self, state_mutable_sequences, mapping: List[int], label: str, **plotting_kwargs): @@ -1967,7 +2050,7 @@ def plot_state_mutable_sequence(self, state_mutable_sequences, mapping: List[int )) def plot_measurements(self, measurements, mapping, measurement_model=None, - measurements_label="Measurements", convert_measurements=True, **kwargs): + label="Measurements", convert_measurements=True, **kwargs): """Plots measurements Plots detections and clutter, generating a legend automatically. Detections are plotted as @@ -1986,7 +2069,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, measurement_model : :class:`~.Model`, optional User-defined measurement model to be used in finding measurement state inverses if they cannot be found from the measurements themselves. - measurements_label: str + label: str Label for measurements. Default is "Detections". convert_measurements: bool Should the measurements be converted from measurement space to state space before @@ -1994,8 +2077,14 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, \\*\\*kwargs: dict Additional arguments to be passed to plot function for detections. Defaults are ``marker='o'`` and ``color='b'``. - """ + + .. deprecated:: 1.5 + ``label`` has replaced ``measurements_label``. In the current implementation + ``measurements_label`` overrides ``label``. However, use of ``measurements_label`` + may be removed in the future. + """ + label = kwargs.pop('measurements_label', None) or label measurement_kwargs = dict(marker='o', color='b') measurement_kwargs.update(kwargs) @@ -2014,9 +2103,9 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_detections: if plot_clutter: - name = measurements_label + "
(Detections)" + name = label + "
(Detections)" else: - name = measurements_label + name = label detection_kwargs = dict(linestyle='', marker='o', color='b') detection_kwargs.update(kwargs) self.plotting_data.append(_AnimationPlotterDataClass( @@ -2034,11 +2123,11 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, plotting_data=[State(state_vector=plotting_state_vector, timestamp=detection.timestamp) for detection, plotting_state_vector in plot_clutter.items()], - plotting_label=measurements_label + "
(Clutter)", + plotting_label=label + "
(Clutter)", plotting_keyword_arguments=clutter_kwargs )) - def plot_sensors(self, sensors, sensor_label="Sensors", **kwargs): + def plot_sensors(self, sensors, label="Sensors", **kwargs): raise NotImplementedError @classmethod @@ -2444,7 +2533,7 @@ def _resize(self, data, type="track"): self.fig.update_yaxes(range=[ymin - yrange / 20, ymax + yrange / 20]) - def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", + def plot_ground_truths(self, truths, mapping, label="Ground Truth", resize=True, **kwargs): """Plots ground truth(s) @@ -2463,14 +2552,20 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", for iteration. mapping: list List of items specifying the mapping of the position components of the state space. - truths_label: str + label: str Name of ground truths in legend/plot resize: bool if True, will resize figure to ensure that ground truths are in view \\*\\*kwargs: dict Additional arguments to be passed to plot function. Default is ``linestyle="--"``. + + .. deprecated:: 1.5 + ``label`` has replaced ``truths_label``. In the current implementation + ``truths_label`` overrides ``label``. However, use of ``truths_label`` + may be removed in the future. """ + label = kwargs.pop('truths_label', None) or label if not isinstance(truths, Collection) or isinstance(truths, StateMutableSequence): truths = {truths} # Make a set of length 1 @@ -2498,9 +2593,9 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", # add a trace that keeps the legend up for the entire simulation (will remain # even if no truths are present), then add a trace for each truth in the simulation. # initialise keyword arguments, then add them to the traces - truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=truths_label, + truth_kwargs = dict(x=[], y=[], mode="lines", hoverinfo='none', legendgroup=label, line=dict(dash="dash", color=self.colorway[0]), legendrank=100, - name=truths_label, showlegend=True) + name=label, showlegend=True) merge(truth_kwargs, kwargs) # legend dummy trace self.fig.add_trace(go.Scatter(truth_kwargs)) @@ -2565,7 +2660,7 @@ def plot_ground_truths(self, truths, mapping, truths_label="Ground Truth", self.plotting_function_called = True def plot_measurements(self, measurements, mapping, measurement_model=None, - resize=True, measurements_label="Measurements", + resize=True, label="Measurements", convert_measurements=True, **kwargs): """Plots measurements @@ -2587,7 +2682,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, they cannot be found from the measurements themselves. resize: bool If True, will resize figure to ensure measurements are in view - measurements_label : str + label : str Label for the measurements. Default is "Measurements". convert_measurements : bool Should the measurements be converted from measurement space to state space before @@ -2595,7 +2690,14 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, \\*\\*kwargs: dict Additional arguments to be passed to scatter function for detections. Defaults are ``marker=dict(color="#636EFA")``. + + + .. deprecated:: 1.5 + ``label`` has replaced ``measurements_label``. In the current implementation + ``measurements_label`` overrides ``label``. However, use of ``measurements_label`` + may be removed in the future. """ + label = kwargs.pop('measurements_label', None) or label if not isinstance(measurements, Collection): measurements = {measurements} # Make a set of length 1 @@ -2648,9 +2750,9 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, # initialise detections if plot_clutter: - name = measurements_label + "
(Detections)" + name = label + "
(Detections)" else: - name = measurements_label + name = label measurement_kwargs = dict(x=[], y=[], mode='markers', name=name, legendgroup=name, @@ -2664,7 +2766,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting # change necessary kwargs to initialise clutter trace - name = measurements_label + "
(Clutter)" + name = label + "
(Clutter)" clutter_kwargs = dict(x=[], y=[], mode='markers', name=name, legendgroup=name, @@ -2726,7 +2828,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, particle=False, plot_history=False, ellipse_points=30, - track_label="Tracks", **kwargs): + label="Tracks", **kwargs): """ Plots each track generated, generating a legend automatically. If 'uncertainty=True', error ellipses are plotted. Tracks are plotted as solid lines with point markers @@ -2754,15 +2856,19 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, If true, plots all particles and uncertainty ellipses up to current time step ellipse_points: int Number of points for polygon approximating ellipse shape - track_label: str + label: str Label to apply to all tracks for legend \\*\\*kwargs: dict Additional arguments to be passed to plot function. Defaults are ``linestyle="-"``, ``marker='s'`` for :class:`~.Update` and ``marker='o'`` for other states. - Returns - ------- + + .. deprecated:: 1.5 + ``label`` has replaced ``track_label``. In the current implementation + ``track_label`` overrides ``label``. However, use of ``track_label`` + may be removed in the future. """ + label = kwargs.pop('track_label', None) or label if not isinstance(tracks, Collection) or isinstance(tracks, StateMutableSequence): tracks = {tracks} # Make a set of length 1 @@ -2797,7 +2903,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, # add dummy trace for legend for track track_kwargs = dict(x=[], y=[], mode="markers+lines", line=dict(color=self.colorway[2]), - legendgroup=track_label, legendrank=400, name=track_label, + legendgroup=label, legendrank=400, name=label, showlegend=True) track_kwargs.update(kwargs) self.fig.add_trace(go.Scatter(track_kwargs)) @@ -2869,7 +2975,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, self._resize(data, "tracks") if uncertainty: # plot ellipses - name = f'{track_label}
Uncertainty' + name = f'{label}
Uncertainty' uncertainty_kwargs = dict(x=[], y=[], legendgroup=name, fill='toself', fillcolor=self.colorway[2], opacity=0.2, legendrank=500, name=name, @@ -2894,7 +3000,7 @@ def plot_tracks(self, tracks, mapping, uncertainty=False, resize=True, if particle: # plot particles # initialise traces. One for legend and one per track - name = f'{track_label}
Particles' + name = f'{label}
Particles' particle_kwargs = dict(mode='markers', marker=dict(size=2, color=self.colorway[2]), opacity=0.4, hoverinfo='skip', legendgroup=name, name=name, @@ -2989,7 +3095,7 @@ def _plot_particles_and_ellipses(self, tracks, mapping, resize, method="uncertai if resize: self._resize(data, type="particle_or_uncertainty") - def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs): + def plot_sensors(self, sensors, label="Sensors", resize=True, **kwargs): """Plots sensor(s) Plots sensors. Users can change the color and marker of detections using keyword @@ -3000,12 +3106,20 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs): ---------- sensors : Collection of :class:`~.Sensor` Sensors to plot - sensor_label: str + label: str Label to apply to all tracks for legend. \\*\\*kwargs: dict Additional arguments to be passed to scatter function for detections. Defaults are ``marker=dict(symbol='x', color='black')``. + + + .. deprecated:: 1.5 + ``label`` has replaced ``sensor_label``. In the current implementation + ``sensor_label`` overrides ``label``. However, use of ``sensor_label`` + may be removed in the future. """ + label = kwargs.pop('sensor_label', None) or label + if not isinstance(sensors, Collection): sensors = {sensors} @@ -3013,8 +3127,8 @@ def plot_sensors(self, sensors, sensor_label="Sensors", resize=True, **kwargs): if sensors: trace_base = len(self.fig.data) # number of traces currently in figure sensor_kwargs = dict(mode='markers', marker=dict(symbol='x', color='black'), - legendgroup=sensor_label, legendrank=50, - name=sensor_label, showlegend=True) + legendgroup=label, legendrank=50, + name=label, showlegend=True) merge(sensor_kwargs, kwargs) self.fig.add_trace(go.Scatter(sensor_kwargs)) # initialises trace From fd71377344deec25d1f3b5d2eb8791e5e9c1017c Mon Sep 17 00:00:00 2001 From: G Webb Date: Thu, 15 Aug 2024 14:27:53 +0100 Subject: [PATCH 4/7] Replaced 'measurements_label', 'track_label', 'sensor_label' and 'truths_label' with 'label' in examples and tutorials. --- .../Comparing_EHM_JPDA_example.py | 6 ++--- .../dataassociation/One_to_One_Associator.py | 22 +++++++++---------- docs/examples/dataassociation/mht_example.py | 2 +- .../eot/extended_object_tracking_example.py | 2 +- .../examples/filters/Multi_Tracker_Example.py | 8 +++---- docs/examples/filters/bearing_only_example.py | 2 +- docs/examples/oosm/KalmanFilterOOSMExample.py | 12 +++++----- docs/examples/oosm/PF_OOSM_example.py | 4 ++-- docs/examples/oosm/example_oosm_algorithm.py | 8 +++---- docs/examples/oosm/example_simple_oosm.py | 6 ++--- docs/examples/plotting/Polar_Plotting.py | 16 +++++++------- .../Example_data_fusion_from_sensors.py | 14 ++++++------ .../trackfusion/Track2Track_Fusion_Example.py | 12 +++++----- .../trackfusion/track_fusion_example.py | 16 +++++++------- docs/tutorials/filters/AKKF.py | 2 +- docs/tutorials/filters/ASDFilter.py | 4 ++-- 16 files changed, 68 insertions(+), 68 deletions(-) diff --git a/docs/examples/dataassociation/Comparing_EHM_JPDA_example.py b/docs/examples/dataassociation/Comparing_EHM_JPDA_example.py index 45ae3fbcf..5b7f4bb5d 100644 --- a/docs/examples/dataassociation/Comparing_EHM_JPDA_example.py +++ b/docs/examples/dataassociation/Comparing_EHM_JPDA_example.py @@ -383,11 +383,11 @@ plotter.plot_ground_truths(groundtruths, [0, 2]) plotter.plot_measurements(detections_set, [0, 2]) plotter.plot_tracks(JPDA_tracks, [0, 2], line= dict(color='orange'), - track_label='JPDA tracks') + label='JPDA tracks') plotter.plot_tracks(EHM1_tracks, [0, 2], line= dict(color='green', dash='dot'), - track_label='EHM1 tracks') + label='EHM1 tracks') plotter.plot_tracks(EHM2_tracks, [0, 2], line= dict(color='red', dash='dot'), - track_label='EHM2 tracks') + label='EHM2 tracks') plotter.fig # %% diff --git a/docs/examples/dataassociation/One_to_One_Associator.py b/docs/examples/dataassociation/One_to_One_Associator.py index 0d8c05808..b91504e57 100644 --- a/docs/examples/dataassociation/One_to_One_Associator.py +++ b/docs/examples/dataassociation/One_to_One_Associator.py @@ -58,11 +58,11 @@ plotter = Plotterly() plotter.plot_tracks(tracks=[Track(state) for state in states_from_a], - mapping=[0, 1], track_label="Source A", + mapping=[0, 1], label="Source A", mode="markers", marker=dict(symbol="cross", color=next(colours_iter))) plotter.plot_tracks(tracks=[Track(state) for state in states_from_b], - mapping=[0, 1], track_label="Source B", + mapping=[0, 1], label="Source B", mode="markers", marker=dict(symbol="circle", color=next(colours_iter))) plotter.fig @@ -109,17 +109,17 @@ track = Track(state_from_a, init_metadata=dict(source="a", association=idx)) plotter.plot_tracks(track, mapping=[0, 1], mode="markers", - track_label=f"{state_names[state_from_a]}, Association {idx}", + label=f"{state_names[state_from_a]}, Association {idx}", marker=dict(symbol="cross", color=colour)) track = Track(state_from_b, init_metadata=dict(source="b", association=idx)) plotter.plot_tracks(track, mapping=[0, 1], mode="markers", - track_label=f"{state_names[state_from_b]}, Association {idx}", + label=f"{state_names[state_from_b]}, Association {idx}", marker=dict(symbol="circle", color=colour)) track = Track([state_from_a, state_from_b], init_metadata=dict(association=idx)) plotter.plot_tracks(track, mapping=[0, 1], mode="lines", - track_label=f"Association {idx}", + label=f"Association {idx}", line=dict(color=colour)) dist_between_states = Euclidean()(state_from_a, state_from_b) @@ -134,7 +134,7 @@ colour = next(colours_iter) track = Track(state, init_metadata=dict(source="a", association=None)) plotter.plot_tracks(track, mapping=[0, 1], - track_label=f"{state_names[state]}, No Association", + label=f"{state_names[state]}, No Association", mode="markers", marker=dict(symbol="cross", color=colour)) for state in unassociated_states_b: @@ -143,7 +143,7 @@ colour = next(colours_iter) track = Track(state, init_metadata=dict(source="b", association=None)) plotter.plot_tracks(track, mapping=[0, 1], - track_label=f"{state_names[state]}, No Association", + label=f"{state_names[state]}, No Association", mode="markers", marker=dict(symbol="circle", color=colour)) # %% @@ -215,11 +215,11 @@ colour = next(colours_iter) for track in tracks_a: - plotter.plot_tracks(track, mapping=[0, 1], track_label=track.id, marker=dict(color=colour)) + plotter.plot_tracks(track, mapping=[0, 1], label=track.id, marker=dict(color=colour)) colour = next(colours_iter) for track in tracks_b: - plotter.plot_tracks(track, mapping=[0, 1], track_label=track.id, marker=dict(color=colour)) + plotter.plot_tracks(track, mapping=[0, 1], label=track.id, marker=dict(color=colour)) plotter.fig @@ -272,7 +272,7 @@ print('Associated together', [track.id for track in assoc.objects]) colour = next(colours_iter) for track in assoc.objects: - plotter.plot_tracks(track, mapping=[0, 1], track_label=track.id, + plotter.plot_tracks(track, mapping=[0, 1], label=track.id, marker=dict(color=colour)) print("Not Associated in A: ", [track.id for track in unassociated_a]) @@ -280,7 +280,7 @@ for track in [*unassociated_a, *unassociated_b]: colour = next(colours_iter) - plotter.plot_tracks(track, mapping=[0, 1], track_label=track.id, marker=dict(color=colour)) + plotter.plot_tracks(track, mapping=[0, 1], label=track.id, marker=dict(color=colour)) plotter.fig diff --git a/docs/examples/dataassociation/mht_example.py b/docs/examples/dataassociation/mht_example.py index 0b056bdb9..ca320f48b 100644 --- a/docs/examples/dataassociation/mht_example.py +++ b/docs/examples/dataassociation/mht_example.py @@ -238,7 +238,7 @@ tracks.add(track) -plotter.plot_tracks(tracks, [0, 2], track_label="Tracks", line=dict(color="Green")) +plotter.plot_tracks(tracks, [0, 2], label="Tracks", line=dict(color="Green")) plotter.fig # %% diff --git a/docs/examples/eot/extended_object_tracking_example.py b/docs/examples/eot/extended_object_tracking_example.py index 02b463f4f..52153b8e5 100644 --- a/docs/examples/eot/extended_object_tracking_example.py +++ b/docs/examples/eot/extended_object_tracking_example.py @@ -349,7 +349,7 @@ tracks.update(current_tracks) plotter.plot_measurements(centroid_detections, [0, 2], marker=dict(color='red'), - measurements_label='Cluster centroids') + label='Cluster centroids') plotter.plot_tracks(tracks, [0, 2]) plotter.fig diff --git a/docs/examples/filters/Multi_Tracker_Example.py b/docs/examples/filters/Multi_Tracker_Example.py index 9c1eca2ed..6b7ffbcf1 100644 --- a/docs/examples/filters/Multi_Tracker_Example.py +++ b/docs/examples/filters/Multi_Tracker_Example.py @@ -365,13 +365,13 @@ # %% # Finally, we plot the results: -plotter.plot_tracks(tracks_EKF, [0, 2], track_label="EKF", line=dict(color="orange"), +plotter.plot_tracks(tracks_EKF, [0, 2], label="EKF", line=dict(color="orange"), uncertainty=False) -plotter.plot_tracks(tracks_UKF, [0, 2], track_label="UKF", line=dict(color="blue"), +plotter.plot_tracks(tracks_UKF, [0, 2], label="UKF", line=dict(color="blue"), uncertainty=False) -plotter.plot_tracks(tracks_PF, [0, 2], track_label="PF", line=dict(color="brown"), +plotter.plot_tracks(tracks_PF, [0, 2], label="PF", line=dict(color="brown"), uncertainty=False) -plotter.plot_tracks(tracks_ESIF, [0, 2], track_label="ESIF", line=dict(color="green"), +plotter.plot_tracks(tracks_ESIF, [0, 2], label="ESIF", line=dict(color="green"), uncertainty=False) plotter.fig diff --git a/docs/examples/filters/bearing_only_example.py b/docs/examples/filters/bearing_only_example.py index 841687693..a9bce17a6 100644 --- a/docs/examples/filters/bearing_only_example.py +++ b/docs/examples/filters/bearing_only_example.py @@ -230,7 +230,7 @@ plotter = AnimationPlotter(legend_kwargs=dict(loc='upper left')) plotter.plot_ground_truths(groundtruth_paths, (0,2)) plotter.plot_tracks(kalman_tracks, (0,2)) -plotter.plot_ground_truths(platform, (0,2), truths_label="Sensor Platform") +plotter.plot_ground_truths(platform, (0,2), label="Sensor Platform") plotter.run() # %% diff --git a/docs/examples/oosm/KalmanFilterOOSMExample.py b/docs/examples/oosm/KalmanFilterOOSMExample.py index d6baf5017..3b567b4c6 100644 --- a/docs/examples/oosm/KalmanFilterOOSMExample.py +++ b/docs/examples/oosm/KalmanFilterOOSMExample.py @@ -155,12 +155,12 @@ plotter = AnimatedPlotterly(timesteps=time_steps) plotter.plot_ground_truths(truth, [0, 2]) plotter.plot_measurements(measurements1, [0, 2], marker=dict(color='blue'), - measurements_label='Detections with no lag') + label='Detections with no lag') plotter.plot_measurements(measurements2, [0, 2], marker=dict(color='orange'), - measurements_label='Detections with lag') + label='Detections with lag') plotter.plot_sensors([sensor1_platform, sensor2_platform], marker=dict(color='black', symbol='129', size=15), - sensor_label='Fixed Platforms') + label='Fixed Platforms') plotter.fig # %% @@ -297,10 +297,10 @@ # However, it is interesting to see a 1-to-1 comparison between the three trackers, even if the # Tracker 2 track is not, visually, lagging behind. -plotter.plot_tracks(track1, [0, 2], track_label='Tracker 1') -plotter.plot_tracks(track2, [0, 2], track_label='Tracker 2', +plotter.plot_tracks(track1, [0, 2], label='Tracker 1') +plotter.plot_tracks(track2, [0, 2], label='Tracker 2', line=dict(color='red')) -plotter.plot_tracks(track3, [0, 2], track_label='Tracker 3', +plotter.plot_tracks(track3, [0, 2], label='Tracker 3', line=dict(color='green')) plotter.fig diff --git a/docs/examples/oosm/PF_OOSM_example.py b/docs/examples/oosm/PF_OOSM_example.py index 4d7e551c5..edf4e786e 100644 --- a/docs/examples/oosm/PF_OOSM_example.py +++ b/docs/examples/oosm/PF_OOSM_example.py @@ -302,9 +302,9 @@ plotter.plot_ground_truths(truths, [0, 2]) plotter.plot_measurements(scans_detections, [0, 2]) -plotter.plot_tracks(track, [0, 2], track_label='Track dealing with OOSM', +plotter.plot_tracks(track, [0, 2], label='Track dealing with OOSM', line=dict(color='blue')) -plotter.plot_tracks(track2, [0, 2], track_label='Track ignoring OOSM') +plotter.plot_tracks(track2, [0, 2], label='Track ignoring OOSM') plotter.fig # %% diff --git a/docs/examples/oosm/example_oosm_algorithm.py b/docs/examples/oosm/example_oosm_algorithm.py index 026c5fc29..28cdb919d 100644 --- a/docs/examples/oosm/example_oosm_algorithm.py +++ b/docs/examples/oosm/example_oosm_algorithm.py @@ -339,11 +339,11 @@ from stonesoup.plotter import AnimatedPlotterly plotter = AnimatedPlotterly(timesteps=timestamps) plotter.plot_ground_truths(truths, [0, 2]) -plotter.plot_measurements(scan_s1, [0, 2], measurements_label='scan1', measurement_model=sensor_1_mm) -plotter.plot_measurements(scan_s2, [0, 2], measurements_label='scan2', measurement_model=sensor_1_mm) -plotter.plot_tracks(oosm_tracks, [0, 2], track_label='OOSM Tracks', +plotter.plot_measurements(scan_s1, [0, 2], label='scan1', measurement_model=sensor_1_mm) +plotter.plot_measurements(scan_s2, [0, 2], label='scan2', measurement_model=sensor_1_mm) +plotter.plot_tracks(oosm_tracks, [0, 2], label='OOSM Tracks', line= dict(color='orange')) -plotter.plot_tracks(noOsm_tracks, [0, 2], track_label='no-OOSM Tracks', +plotter.plot_tracks(noOsm_tracks, [0, 2], label='no-OOSM Tracks', line= dict(color='red')) plotter.fig diff --git a/docs/examples/oosm/example_simple_oosm.py b/docs/examples/oosm/example_simple_oosm.py index 46b25720d..4ca05548b 100644 --- a/docs/examples/oosm/example_simple_oosm.py +++ b/docs/examples/oosm/example_simple_oosm.py @@ -201,9 +201,9 @@ plotter = AnimatedPlotterly(timesteps=timestamps) plotter.plot_ground_truths(truths, [0, 2]) -plotter.plot_measurements(scans_detections, [0, 2], measurements_label='Detections', +plotter.plot_measurements(scans_detections, [0, 2], label='Detections', measurement_model=measurement_model) -plotter.plot_tracks(track_lag, [0, 2], line= dict(color='grey'), track_label='Track with lag') +plotter.plot_tracks(track_lag, [0, 2], line= dict(color='grey'), label='Track with lag') plotter.fig # %% @@ -262,7 +262,7 @@ # Plotting the final track # ------------------------ -plotter.plot_tracks(track, [0, 2], line= dict(color='blue'), track_label='Track with OOSM treated') +plotter.plot_tracks(track, [0, 2], line= dict(color='blue'), label='Track with OOSM treated') plotter.fig # %% diff --git a/docs/examples/plotting/Polar_Plotting.py b/docs/examples/plotting/Polar_Plotting.py index d6ada97e2..bc0d6d222 100644 --- a/docs/examples/plotting/Polar_Plotting.py +++ b/docs/examples/plotting/Polar_Plotting.py @@ -85,8 +85,8 @@ # :class:`~.Plotterly` plotting class: plotter_xy = Plotterly(title="Bird's Eye View of Targets") mapping = [0, 2] -plotter_xy.plot_ground_truths(target_1, mapping=[0, 2], truths_label="Target 1") -plotter_xy.plot_ground_truths(target_2, mapping=[0, 2], truths_label="Target 2") +plotter_xy.plot_ground_truths(target_1, mapping=[0, 2], label="Target 1") +plotter_xy.plot_ground_truths(target_2, mapping=[0, 2], label="Target 2") plotter_xy.fig # %% @@ -130,9 +130,9 @@ yaxis=dict(title=dict(text="Bearing (Radians)")) ) plotter_az_t_cart.plot_ground_truths({angular_ground_truth_1}, - mapping=mapping, truths_label="Target 1") + mapping=mapping, label="Target 1") plotter_az_t_cart.plot_ground_truths({angular_ground_truth_2}, - mapping=mapping, truths_label="Target 2") + mapping=mapping, label="Target 2") plotter_az_t_cart.plot_measurements(detections, mapping=mapping, convert_measurements=False) plotter_az_t_cart.fig @@ -154,8 +154,8 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ mapping = [0] plotter_az_t = PolarPlotterly(title="Azimuth Angle (Degrees) vs Time (s)") -plotter_az_t.plot_ground_truths({angular_ground_truth_1}, mapping=mapping, truths_label="Target 1") -plotter_az_t.plot_ground_truths({angular_ground_truth_2}, mapping=mapping, truths_label="Target 2") +plotter_az_t.plot_ground_truths({angular_ground_truth_1}, mapping=mapping, label="Target 1") +plotter_az_t.plot_ground_truths({angular_ground_truth_2}, mapping=mapping, label="Target 2") plotter_az_t.plot_measurements(detections, mapping=mapping, convert_measurements=False) plotter_az_t.fig @@ -175,8 +175,8 @@ # sphinx_gallery_thumbnail_number = 4 plotter_az_r = PolarPlotterly(title="Azimuth Angle (Degrees) vs Range (m)") mapping = [0, 1] -plotter_az_r.plot_ground_truths({angular_ground_truth_1}, mapping=mapping, truths_label="Target 1") -plotter_az_r.plot_ground_truths({angular_ground_truth_2}, mapping=mapping, truths_label="Target 2") +plotter_az_r.plot_ground_truths({angular_ground_truth_1}, mapping=mapping, label="Target 1") +plotter_az_r.plot_ground_truths({angular_ground_truth_2}, mapping=mapping, label="Target 2") plotter_az_r.plot_measurements(detections, mapping=mapping, convert_measurements=False) plotter_az_r.fig diff --git a/docs/examples/sensorfusion/Example_data_fusion_from_sensors.py b/docs/examples/sensorfusion/Example_data_fusion_from_sensors.py index 09c4d5745..3fce8f0ed 100644 --- a/docs/examples/sensorfusion/Example_data_fusion_from_sensors.py +++ b/docs/examples/sensorfusion/Example_data_fusion_from_sensors.py @@ -459,19 +459,19 @@ plotter = Plotterly() plotter.plot_measurements(s1_detections, [0, 2], - measurements_label='Radar 1 measurements'), + label='Radar 1 measurements'), plotter.plot_measurements(s2_detections, [0, 2], - measurements_label='Radar 2 measurements') -plotter.plot_tracks(ukf_tracks, [0, 2], line=dict(color='green'), track_label='UKF tracks') -plotter.plot_tracks(ekf_tracks, [0, 2], line=dict(color='blue'), track_label='EKF tracks') + label='Radar 2 measurements') +plotter.plot_tracks(ukf_tracks, [0, 2], line=dict(color='green'), label='UKF tracks') +plotter.plot_tracks(ekf_tracks, [0, 2], line=dict(color='blue'), label='EKF tracks') plotter.plot_tracks(pf_tracks, [0, 2], particle=False, line=dict(color='red'), - track_label='PF tracks') + label='PF tracks') plotter.plot_ground_truths(truths, [0, 2]) plotter.plot_sensors(sensor1_platform, [0, 1], marker=dict(color='black', symbol='129', size=15), - sensor_label='Fixed Platform') + label='Fixed Platform') plotter.plot_ground_truths(sensor2_platform, [0, 2], marker=dict(color='orange', symbol='cross', size=25), - truths_label='Moving Platform') + label='Moving Platform') plotter.fig # %% diff --git a/docs/examples/trackfusion/Track2Track_Fusion_Example.py b/docs/examples/trackfusion/Track2Track_Fusion_Example.py index 033a0598a..112a901dc 100644 --- a/docs/examples/trackfusion/Track2Track_Fusion_Example.py +++ b/docs/examples/trackfusion/Track2Track_Fusion_Example.py @@ -629,17 +629,17 @@ plotter.plot_ground_truths(truths, [0, 2], color='black') plotter.plot_measurements(s1_detections, [0, 2], color='orange', marker='*', - measurements_label='Measurements - Airborne Radar') + label='Measurements - Airborne Radar') plotter.plot_measurements(s2_detections, [0, 2], color='blue', marker='*', - measurements_label='Measurements - Ground Radar') + label='Measurements - Ground Radar') plotter.plot_tracks(jpda_tracks, [0, 2], color='red', - track_label='Tracks - Airborne Radar (JPDAF)') + label='Tracks - Airborne Radar (JPDAF)') plotter.plot_tracks(gmlcc_tracks, [0, 2], color='purple', - track_label='Tracks - Ground Radar (GM-LCC)') + label='Tracks - Ground Radar (GM-LCC)') plotter.plot_tracks(meas_fusion_tracks, [0, 2], color='green', - track_label='Tracks - Measurement Fusion (GM-PHD)') + label='Tracks - Measurement Fusion (GM-PHD)') plotter.plot_tracks(track_fusion_tracks, [0, 2], color='pink', - track_label='Tracks - Covariance Intersection (GM-PHD)') + label='Tracks - Covariance Intersection (GM-PHD)') # Format the legend a bit. Set the position outside of the plot, and # swap the order of the clutter and ground radar measurements diff --git a/docs/examples/trackfusion/track_fusion_example.py b/docs/examples/trackfusion/track_fusion_example.py index be952a94b..ed2930661 100644 --- a/docs/examples/trackfusion/track_fusion_example.py +++ b/docs/examples/trackfusion/track_fusion_example.py @@ -185,9 +185,9 @@ # Plot the detections from the two radars plotter = Plotterly() plotter.plot_measurements(s1_detections, [0, 2], marker=dict(color='red'), - measurements_label='Sensor 1 measurements') + label='Sensor 1 measurements') plotter.plot_measurements(s2_detections, [0, 2], marker=dict(color='blue'), - measurements_label='Sensor 2 measurements') + label='Sensor 2 measurements') plotter.plot_sensors({sensor1_platform, sensor2_platform}, [0, 1], marker=dict(color='black', symbol='1', size=10)) plotter.plot_ground_truths(truths, [0, 2]) @@ -431,12 +431,12 @@ # Let's visualise the various tracks and detections in the cases # using the Kalman and particle filters. -plotter.plot_tracks(PF_track1, [0, 2], line=dict(color="orange"), track_label='PF partial track 1') -plotter.plot_tracks(PF_track2, [0, 2], line=dict(color="gold"), track_label='PF partial track 2') -plotter.plot_tracks(PF_fused_track, [0, 2], line=dict(color="red"), track_label='PF fused track') -plotter.plot_tracks(KF_fused_track, [0, 2], line=dict(color="blue"), track_label='KF fused track') -plotter.plot_tracks(KF_track1, [0, 2], line=dict(color="cyan"), track_label='KF partial track 1') -plotter.plot_tracks(KF_track2, [0, 2], line=dict(color="skyblue"), track_label='KF partial track 2') +plotter.plot_tracks(PF_track1, [0, 2], line=dict(color="orange"), label='PF partial track 1') +plotter.plot_tracks(PF_track2, [0, 2], line=dict(color="gold"), label='PF partial track 2') +plotter.plot_tracks(PF_fused_track, [0, 2], line=dict(color="red"), label='PF fused track') +plotter.plot_tracks(KF_fused_track, [0, 2], line=dict(color="blue"), label='KF fused track') +plotter.plot_tracks(KF_track1, [0, 2], line=dict(color="cyan"), label='KF partial track 1') +plotter.plot_tracks(KF_track2, [0, 2], line=dict(color="skyblue"), label='KF partial track 2') plotter.fig diff --git a/docs/tutorials/filters/AKKF.py b/docs/tutorials/filters/AKKF.py index 404554f85..8b95479d4 100644 --- a/docs/tutorials/filters/AKKF.py +++ b/docs/tutorials/filters/AKKF.py @@ -429,7 +429,7 @@ plotter = Plotter() plotter.plot_ground_truths(truth, [0, 2], linewidth=3.0, color='black') -plotter.plot_tracks(track, [0, 2], track_label='AKKF - quadratic', color='royalblue') +plotter.plot_tracks(track, [0, 2], label='AKKF - quadratic', color='royalblue') plotter.fig # %% diff --git a/docs/tutorials/filters/ASDFilter.py b/docs/tutorials/filters/ASDFilter.py index 6bbcfd221..96d756ecd 100644 --- a/docs/tutorials/filters/ASDFilter.py +++ b/docs/tutorials/filters/ASDFilter.py @@ -189,9 +189,9 @@ asd_states = sorted(asd_states, key=attrgetter('timestamp')) plotter.plot_tracks({track2}, [0, 2], uncertainty=True, line=dict(color='green'), - track_label="Equivalent track without ASD") + label="Equivalent track without ASD") plotter.plot_tracks({Track(asd_states)}, [0, 2], line=dict(color='red'), - track_label="ASD Track") + label="ASD Track") plotter.fig # %% From 75d743a8591ed555cdf09bd890d2f334470d8656 Mon Sep 17 00:00:00 2001 From: G Webb Date: Fri, 16 Aug 2024 16:36:52 +0100 Subject: [PATCH 5/7] Added parametrized tests in stonesoup/tests/test_plotter.py Changed measurement labels in stonesoup/plotter.py for matplotlib plotters --- stonesoup/plotter.py | 60 ++++----- stonesoup/tests/test_plotter.py | 215 +++++++++++++++++++++++--------- 2 files changed, 189 insertions(+), 86 deletions(-) diff --git a/stonesoup/plotter.py b/stonesoup/plotter.py index ef3d8fab4..f994519e2 100644 --- a/stonesoup/plotter.py +++ b/stonesoup/plotter.py @@ -298,7 +298,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, # Generate legend items for measurements if plot_clutter: - name = label + "
(Detections)" + name = label + "\n(Detections)" else: name = label self.legend_dict[name] = measurements_handle @@ -311,7 +311,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, clutter_handle = Line2D([], [], linestyle='', **clutter_kwargs) # Generate legend items for clutter - name = label + "
(Clutter)" + name = label + "\n(Clutter)" self.legend_dict[name] = clutter_handle # Generate legend @@ -2103,7 +2103,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, if plot_detections: if plot_clutter: - name = label + "
(Detections)" + name = label + "\n(Detections)" else: name = label detection_kwargs = dict(linestyle='', marker='o', color='b') @@ -2123,7 +2123,7 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, plotting_data=[State(state_vector=plotting_state_vector, timestamp=detection.timestamp) for detection, plotting_state_vector in plot_clutter.items()], - plotting_label=label + "
(Clutter)", + plotting_label=label + "\n(Clutter)", plotting_keyword_arguments=clutter_kwargs )) @@ -2748,34 +2748,36 @@ def plot_measurements(self, measurements, mapping, measurement_model=None, # get number of traces currently in fig trace_base = len(self.fig.data) - # initialise detections + if plot_detections: + # initialise detections + if plot_clutter: + name = label + "
(Detections)" + else: + name = label + measurement_kwargs = dict(x=[], y=[], mode='markers', + name=name, + legendgroup=name, + legendrank=200, showlegend=True, + marker=dict(color="#636EFA"), hoverinfo='none') + merge(measurement_kwargs, kwargs) + + self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for legend + + measurement_kwargs.update({"showlegend": False}) + self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting + if plot_clutter: - name = label + "
(Detections)" - else: - name = label - measurement_kwargs = dict(x=[], y=[], mode='markers', + # change necessary kwargs to initialise clutter trace + name = label + "
(Clutter)" + clutter_kwargs = dict(x=[], y=[], mode='markers', name=name, legendgroup=name, - legendrank=200, showlegend=True, - marker=dict(color="#636EFA"), hoverinfo='none') - merge(measurement_kwargs, kwargs) - - self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for legend - - measurement_kwargs.update({"showlegend": False}) - self.fig.add_trace(go.Scatter(measurement_kwargs)) # trace for plotting - - # change necessary kwargs to initialise clutter trace - name = label + "
(Clutter)" - clutter_kwargs = dict(x=[], y=[], mode='markers', - name=name, - legendgroup=name, - legendrank=300, showlegend=True, - marker=dict(symbol="star-triangle-up", color='#FECB52'), - hoverinfo='none') - merge(clutter_kwargs, kwargs) - - self.fig.add_trace(go.Scatter(clutter_kwargs)) # trace for plotting clutter + legendrank=300, showlegend=True, + marker=dict(symbol="star-triangle-up", color='#FECB52'), + hoverinfo='none') + merge(clutter_kwargs, kwargs) + + self.fig.add_trace(go.Scatter(clutter_kwargs)) # trace for plotting clutter # add data to frames for frame in self.fig.frames: diff --git a/stonesoup/tests/test_plotter.py b/stonesoup/tests/test_plotter.py index eb0ec2fe3..9a23b4f0f 100644 --- a/stonesoup/tests/test_plotter.py +++ b/stonesoup/tests/test_plotter.py @@ -1,31 +1,26 @@ +from datetime import datetime, timedelta + +import matplotlib.pyplot as plt import numpy as np -from stonesoup.plotter import Plotter, Dimension, AnimatedPlotterly, AnimationPlotter, Plotterly import pytest -import matplotlib.pyplot as plt - -# Setup simulation to test the plotter functionality -from datetime import datetime -from datetime import timedelta -from stonesoup.types.detection import TrueDetection +from stonesoup.dataassociator.neighbour import NearestNeighbour +from stonesoup.hypothesiser.distance import DistanceHypothesiser +from stonesoup.measures import Mahalanobis from stonesoup.models.measurement.linear import LinearGaussian -from stonesoup.sensor.radar.radar import RadarElevationBearingRange - from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \ - ConstantVelocity -from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState - + ConstantVelocity +from stonesoup.plotter import Plotter, Dimension, AnimatedPlotterly, AnimationPlotter, Plotterly, \ + PolarPlotterly from stonesoup.predictor.kalman import KalmanPredictor -from stonesoup.updater.kalman import KalmanUpdater - -from stonesoup.hypothesiser.distance import DistanceHypothesiser -from stonesoup.measures import Mahalanobis - -from stonesoup.dataassociator.neighbour import NearestNeighbour +from stonesoup.sensor.radar.radar import RadarElevationBearingRange +from stonesoup.types.detection import TrueDetection, Clutter +from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState from stonesoup.types.state import GaussianState, State - from stonesoup.types.track import Track +from stonesoup.updater.kalman import KalmanUpdater +# Setup simulation to test the plotter functionality start_time = datetime.now() transition_model = CombinedLinearGaussianTransitionModel([ConstantVelocity(0.005), ConstantVelocity(0.005)]) @@ -42,7 +37,7 @@ mapping=(0, 2), noise_covar=np.array([[0.75, 0], [0, 0.75]])) -all_measurements = [] +true_measurements = [] for state in truth: measurement_set = set() # Generate actual detection from the state with a 1-p_d chance that no detection is received. @@ -53,7 +48,26 @@ timestamp=state.timestamp, measurement_model=measurement_model)) - all_measurements.append(measurement_set) + true_measurements.append(measurement_set) + +prob_clutter = 0.8 +clutter_measurements = [] +for state in truth: + clutter_measurement_set = set() + # Generate clutter detections + if np.random.rand() <= prob_clutter: + random_state = state.from_state( + state=state, + state_vector=np.random.uniform(-20, 20, size=state.state_vector.size) + ) + measurement = measurement_model.function(random_state, noise=True) + clutter_measurement_set.add(Clutter(state_vector=measurement, + timestamp=state.timestamp, + measurement_model=measurement_model)) + + clutter_measurements.append(clutter_measurement_set) + +all_measurements = [*true_measurements, *clutter_measurements] predictor = KalmanPredictor(transition_model) updater = KalmanUpdater(measurement_model) @@ -64,7 +78,7 @@ # Create prior prior = GaussianState([[0], [1], [0], [1]], np.diag([1.5, 0.5, 1.5, 0.5]), timestamp=start_time) track = Track([prior]) -for n, measurements in enumerate(all_measurements): +for n, measurements in enumerate(true_measurements): hypotheses = data_associator.associate([track], measurements, start_time + timedelta(seconds=n)) @@ -91,32 +105,13 @@ position=np.array([[10], [50], [0]]) ) -plotter = Plotter() -# Test functions - +# Test functions def test_dimension_inlist(): # ensure dimension type is in predefined enum list with pytest.raises(AttributeError): Plotter(dimension=Dimension.TESTERROR) -def test_measurements_legend(): - plotter.plot_measurements(all_measurements, [0, 2]) # Measurements entry in legend dict - plt.close() - assert 'Measurements' in plotter.legend_dict - - -def test_measurement_clutter(): # no clutter should be plotted - plotter.plot_measurements(all_measurements, [0, 2]) - plt.close() - assert 'Clutter' not in plotter.legend_dict - - -def test_single_measurement(): # A single measurement outside of a Collection should still run - plotter.plot_measurements(all_measurements[0], [0, 2]) - plt.close() - - def test_particle_3d(): # warning should arise if particle is attempted in 3d mode plotter3 = Plotter(dimension=Dimension.THREE) @@ -131,9 +126,17 @@ def test_plot_sensors(): assert 'Sensors' in plotter3d.legend_dict -def test_empty_tracks(): +def create_animated_plotterly(): + """Generates a AnimatedPlotterly object. Used for parameterized testing.""" + return AnimatedPlotterly(timesteps) + + +@pytest.mark.parametrize( + "plotter_class", + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) +def test_empty_tracks(plotter_class): + plotter = plotter_class() plotter.plot_tracks(set(), [0, 2]) - plt.close() def test_figsize(): @@ -217,7 +220,7 @@ def test_plot_complex_uncertainty(): def test_animation_plotter(): animation_plotter = AnimationPlotter() animation_plotter.plot_ground_truths(truth, [0, 2]) - animation_plotter.plot_measurements(all_measurements, [0, 2]) + animation_plotter.plot_measurements(true_measurements, [0, 2]) animation_plotter.run() animation_plotter_with_title = AnimationPlotter(title="Plot title") @@ -229,7 +232,7 @@ def test_animation_plotter(): def test_animated_plotterly(): plotter = AnimatedPlotterly(timesteps) plotter.plot_ground_truths(truth, [0, 2]) - plotter.plot_measurements(all_measurements, [0, 2]) + plotter.plot_measurements(true_measurements, [0, 2]) plotter.plot_tracks(track, [0, 2], uncertainty=True, plot_history=True) @@ -256,19 +259,19 @@ def test_animated_plotterly_uneven_times(): def test_plotterly_empty(): plotter = Plotterly() - plotter.plot_ground_truths({}, [0, 2]) - plotter.plot_measurements({}, [0, 2]) - plotter.plot_tracks({}, [0, 2]) + plotter.plot_ground_truths(set(), [0, 2]) + plotter.plot_measurements(set(), [0, 2]) + plotter.plot_tracks(set(), [0, 2]) with pytest.raises(TypeError): - plotter.plot_tracks({}) + plotter.plot_tracks(set()) with pytest.raises(ValueError): - plotter.plot_tracks({}, []) + plotter.plot_tracks(set(), []) def test_plotterly_1d(): plotter1d = Plotterly(dimension=1) plotter1d.plot_ground_truths(truth, [0]) - plotter1d.plot_measurements(all_measurements, [0]) + plotter1d.plot_measurements(true_measurements, [0]) plotter1d.plot_tracks(track, [0]) # check that particle=True does not plot @@ -283,7 +286,7 @@ def test_plotterly_1d(): def test_plotterly_2d(): plotter2d = Plotterly() plotter2d.plot_ground_truths(truth, [0, 2]) - plotter2d.plot_measurements(all_measurements, [0, 2]) + plotter2d.plot_measurements(true_measurements, [0, 2]) plotter2d.plot_tracks(track, [0, 2], uncertainty=True) plotter2d.plot_sensors(sensor2d) @@ -291,7 +294,7 @@ def test_plotterly_2d(): def test_plotterly_3d(): plotter3d = Plotterly(dimension=3) plotter3d.plot_ground_truths(truth, [0, 1, 2]) - plotter3d.plot_measurements(all_measurements, [0, 1, 2]) + plotter3d.plot_measurements(true_measurements, [0, 1, 2]) plotter3d.plot_tracks(track, [0, 1, 2], uncertainty=True) with pytest.raises(NotImplementedError): @@ -313,7 +316,7 @@ def test_plotterly_wrong_dimension(dim, mapping): plotter.plot_ground_truths(truth, mapping) with pytest.raises(TypeError): - plotter.plot_measurements(all_measurements, mapping) + plotter.plot_measurements(true_measurements, mapping) with pytest.raises(TypeError): plotter.plot_tracks(track, mapping) @@ -325,7 +328,7 @@ def test_plotterly_wrong_dimension(dim, mapping): def test_hide_plot(labels): plotter = Plotterly() plotter.plot_ground_truths(truth, [0, 1]) - plotter.plot_measurements(all_measurements, [0, 1]) + plotter.plot_measurements(true_measurements, [0, 1]) plotter.plot_tracks(track, [0, 1]) plotter.hide_plot_traces(labels) @@ -352,7 +355,7 @@ def test_hide_plot(labels): def test_show_plot(labels): plotter = Plotterly() plotter.plot_ground_truths(truth, [0, 1]) - plotter.plot_measurements(all_measurements, [0, 1]) + plotter.plot_measurements(true_measurements, [0, 1]) plotter.plot_tracks(track, [0, 1]) plotter.show_plot_traces(labels) @@ -371,3 +374,101 @@ def test_show_plot(labels): else: assert showing == len(labels) assert showing + hidden == 3 + + +@pytest.mark.parametrize( + "plotter_class", + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) +@pytest.mark.parametrize( + "_measurements", + [true_measurements, clutter_measurements, all_measurements, + all_measurements[0] # Tests a single measurement outside of a Collection should still run + ]) +def test_plotters_plot_measurements_2d(plotter_class, _measurements): + plotter = plotter_class() + plotter.plot_measurements(_measurements, [0, 2]) + + +@pytest.mark.parametrize( + "plotter_class", + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) +def test_plotters_plot_tracks(plotter_class): + plotter = plotter_class() + plotter.plot_tracks(track, [0, 2]) + + +@pytest.mark.parametrize( + "plotter_class", + [Plotter, + Plotterly, + pytest.param(AnimationPlotter, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(PolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError)), + create_animated_plotterly] +) +def test_plotters_plot_track_uncertainty(plotter_class): + plotter = plotter_class() + plotter.plot_tracks(track, [0, 2], uncertainty=True) + + +@pytest.mark.xfail(raises=NotImplementedError) +@pytest.mark.parametrize( + "plotter_class", + [AnimationPlotter, + PolarPlotterly] +) +def test_plotters_plot_track_particle(plotter_class): + plotter = plotter_class() + plotter.plot_tracks(track, [0, 2], particle=True) + + +@pytest.mark.parametrize( + "plotter_class", + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) +def test_plotters_plot_truths(plotter_class): + plotter = plotter_class() + plotter.plot_ground_truths(truth, [0, 2]) + + +@pytest.mark.parametrize( + "plotter_class", + [Plotter, + Plotterly, + pytest.param(AnimationPlotter, marks=pytest.mark.xfail(raises=NotImplementedError)), + pytest.param(PolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError)), + create_animated_plotterly] +) +def test_plotters_plot_sensors(plotter_class): + plotter = plotter_class() + plotter.plot_sensors(sensor2d) + + +@pytest.mark.parametrize("plotter_class", + [Plotterly, PolarPlotterly, create_animated_plotterly]) +@pytest.mark.parametrize("_measurements, expected_labels", + [(true_measurements, {'Measurements'}), + (clutter_measurements, {'Measurements
(Clutter)'}), + (all_measurements, {'Measurements
(Detections)', + 'Measurements
(Clutter)'}) + ]) +def test_plotterlys_plot_measurements_label(plotter_class, _measurements, expected_labels): + plotter = plotter_class() + plotter.plot_measurements(_measurements, [0, 2]) + actual_labels = {fig_data.legendgroup for fig_data in plotter.fig.data} + assert actual_labels == expected_labels + + +@pytest.mark.parametrize("_measurements, expected_labels", + [(true_measurements, {'Measurements'}), + (clutter_measurements, {'Measurements\n(Clutter)'}), + (all_measurements, {'Measurements\n(Detections)', + 'Measurements\n(Clutter)'}) + ]) +def test_plotter_plot_measurements_label(_measurements, expected_labels): + plotter = Plotter() + plotter.plot_measurements(_measurements, [0, 2]) + actual_labels = set(plotter.legend_dict.keys()) + assert actual_labels == expected_labels + + +def test_close_all_figures(): + plt.close('all') From d2aba390d1b1214594873247237b4a0eb6f217cd Mon Sep 17 00:00:00 2001 From: G Webb Date: Mon, 19 Aug 2024 13:26:39 +0100 Subject: [PATCH 6/7] working progress on closing figures after using matplotlib --- stonesoup/tests/test_plotter.py | 60 ++++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/stonesoup/tests/test_plotter.py b/stonesoup/tests/test_plotter.py index 9a23b4f0f..20d37686a 100644 --- a/stonesoup/tests/test_plotter.py +++ b/stonesoup/tests/test_plotter.py @@ -106,6 +106,37 @@ ) +@pytest.fixture(scope="function") +def plotter_class(request): + + plotter_class = request.param + assert plotter_class in {Plotter, Plotterly, AnimationPlotter, + PolarPlotterly, AnimatedPlotterly} + + figures_to_close = [] + + def _generate_animated_plotterly(*args, **kwargs): + return AnimatedPlotterly(*args, timesteps=timesteps, **kwargs) + + def _generate_plotter(*args, **kwargs): + _plotter = Plotter(*args, **kwargs) + figures_to_close.append(_plotter.fig) + return _plotter + + def _generate_other_plotter(*args, **kwargs): + return plotter_class(*args, **kwargs) + + if plotter_class is Plotter: + yield _generate_plotter + elif plotter_class is AnimatedPlotterly: + yield _generate_animated_plotterly + else: + yield _generate_other_plotter + + for fig in figures_to_close: + plt.close(fig) + + # Test functions def test_dimension_inlist(): # ensure dimension type is in predefined enum list with pytest.raises(AttributeError): @@ -126,14 +157,9 @@ def test_plot_sensors(): assert 'Sensors' in plotter3d.legend_dict -def create_animated_plotterly(): - """Generates a AnimatedPlotterly object. Used for parameterized testing.""" - return AnimatedPlotterly(timesteps) - - @pytest.mark.parametrize( "plotter_class", - [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True) def test_empty_tracks(plotter_class): plotter = plotter_class() plotter.plot_tracks(set(), [0, 2]) @@ -378,7 +404,7 @@ def test_show_plot(labels): @pytest.mark.parametrize( "plotter_class", - [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True) @pytest.mark.parametrize( "_measurements", [true_measurements, clutter_measurements, all_measurements, @@ -391,7 +417,7 @@ def test_plotters_plot_measurements_2d(plotter_class, _measurements): @pytest.mark.parametrize( "plotter_class", - [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True) def test_plotters_plot_tracks(plotter_class): plotter = plotter_class() plotter.plot_tracks(track, [0, 2]) @@ -403,7 +429,8 @@ def test_plotters_plot_tracks(plotter_class): Plotterly, pytest.param(AnimationPlotter, marks=pytest.mark.xfail(raises=NotImplementedError)), pytest.param(PolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError)), - create_animated_plotterly] + AnimatedPlotterly], + indirect=True ) def test_plotters_plot_track_uncertainty(plotter_class): plotter = plotter_class() @@ -423,7 +450,7 @@ def test_plotters_plot_track_particle(plotter_class): @pytest.mark.parametrize( "plotter_class", - [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, create_animated_plotterly]) + [Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly], indirect=True) def test_plotters_plot_truths(plotter_class): plotter = plotter_class() plotter.plot_ground_truths(truth, [0, 2]) @@ -435,7 +462,7 @@ def test_plotters_plot_truths(plotter_class): Plotterly, pytest.param(AnimationPlotter, marks=pytest.mark.xfail(raises=NotImplementedError)), pytest.param(PolarPlotterly, marks=pytest.mark.xfail(raises=NotImplementedError)), - create_animated_plotterly] + AnimatedPlotterly], indirect=True ) def test_plotters_plot_sensors(plotter_class): plotter = plotter_class() @@ -443,7 +470,7 @@ def test_plotters_plot_sensors(plotter_class): @pytest.mark.parametrize("plotter_class", - [Plotterly, PolarPlotterly, create_animated_plotterly]) + [Plotterly, PolarPlotterly, AnimatedPlotterly], indirect=True) @pytest.mark.parametrize("_measurements, expected_labels", [(true_measurements, {'Measurements'}), (clutter_measurements, {'Measurements
(Clutter)'}), @@ -463,12 +490,13 @@ def test_plotterlys_plot_measurements_label(plotter_class, _measurements, expect (all_measurements, {'Measurements\n(Detections)', 'Measurements\n(Clutter)'}) ]) -def test_plotter_plot_measurements_label(_measurements, expected_labels): - plotter = Plotter() +def test_plotter_plot_measurements_label(plotter_class, _measurements, expected_labels): + plotter = plotter_class() plotter.plot_measurements(_measurements, [0, 2]) actual_labels = set(plotter.legend_dict.keys()) assert actual_labels == expected_labels -def test_close_all_figures(): - plt.close('all') +# def test_close_all_figures(): +# plt.show() + # plt.close('all') From c28b99a6c4fdc2ca2bf64c4af2f3bebc789a2150 Mon Sep 17 00:00:00 2001 From: G Webb Date: Mon, 19 Aug 2024 15:09:07 +0100 Subject: [PATCH 7/7] Tidied up some of the plotting tests --- stonesoup/tests/test_plotter.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/stonesoup/tests/test_plotter.py b/stonesoup/tests/test_plotter.py index 20d37686a..49ac9615b 100644 --- a/stonesoup/tests/test_plotter.py +++ b/stonesoup/tests/test_plotter.py @@ -106,35 +106,25 @@ ) -@pytest.fixture(scope="function") +@pytest.fixture(scope="module") def plotter_class(request): plotter_class = request.param assert plotter_class in {Plotter, Plotterly, AnimationPlotter, PolarPlotterly, AnimatedPlotterly} - figures_to_close = [] - def _generate_animated_plotterly(*args, **kwargs): return AnimatedPlotterly(*args, timesteps=timesteps, **kwargs) def _generate_plotter(*args, **kwargs): - _plotter = Plotter(*args, **kwargs) - figures_to_close.append(_plotter.fig) - return _plotter - - def _generate_other_plotter(*args, **kwargs): return plotter_class(*args, **kwargs) - if plotter_class is Plotter: + if plotter_class in {Plotter, Plotterly, AnimationPlotter, PolarPlotterly}: yield _generate_plotter elif plotter_class is AnimatedPlotterly: yield _generate_animated_plotterly else: - yield _generate_other_plotter - - for fig in figures_to_close: - plt.close(fig) + raise ValueError("Invalid Plotter type.") # Test functions @@ -490,13 +480,14 @@ def test_plotterlys_plot_measurements_label(plotter_class, _measurements, expect (all_measurements, {'Measurements\n(Detections)', 'Measurements\n(Clutter)'}) ]) -def test_plotter_plot_measurements_label(plotter_class, _measurements, expected_labels): - plotter = plotter_class() +def test_plotter_plot_measurements_label(_measurements, expected_labels): + plotter = Plotter() plotter.plot_measurements(_measurements, [0, 2]) actual_labels = set(plotter.legend_dict.keys()) assert actual_labels == expected_labels -# def test_close_all_figures(): -# plt.show() - # plt.close('all') +def teardown_module(): + """Closes all matplotlib plots. + Without this code plots would remain in the background for the duration of all the tests.""" + plt.close('all')