Skip to content

Commit

Permalink
Merge pull request #1048 from timothy-glover/sensor_management_reward_mod
Browse files Browse the repository at this point in the history

Update UncertaintyRewardFunction to return tracks and minor modifications
  • Loading branch information
sdhiscocks authored Jun 19, 2024
2 parents f67aead + 2e262df commit 9a95387
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
# General imports and environment setup
import numpy as np
from datetime import datetime, timedelta
import random

np.random.seed(1991)

Expand Down Expand Up @@ -253,13 +252,15 @@ def constraint_function(particle_state):
reward_updater = ParticleUpdater(measurement_model=None)

# Myopic benchmark approach
reward_funcA = ExpectedKLDivergence(updater=reward_updater)
reward_funcA = ExpectedKLDivergence(updater=reward_updater, measurement_noise=True)
sensormanagerA = BruteForceSensorManager(sensors={gas_sensorA},
platforms={sensor_platformA},
reward_function=reward_funcA)

# MCTS with rollout approach
reward_funcB = ExpectedKLDivergence(updater=reward_updater, return_tracks=True)
reward_funcB = ExpectedKLDivergence(updater=reward_updater,
measurement_noise=True,
return_tracks=True)
sensormanagerB = MCTSRolloutSensorManager(sensors={gas_sensorB},
platforms={sensor_platformB},
reward_function=reward_funcB,
Expand Down
46 changes: 37 additions & 9 deletions stonesoup/sensormanager/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ..updater.base import Updater
from ..updater.particle import ParticleUpdater
from ..resampler.particle import SystematicResampler
from ..types.state import State
from ..types.groundtruth import GroundTruthState
from ..dataassociator.base import DataAssociator


Expand Down Expand Up @@ -71,6 +71,14 @@ class UncertaintyRewardFunction(RewardFunction):
method_sum: bool = Property(default=True, doc="Determines method of calculating reward."
"Default calculates sum across all targets."
"Otherwise calculates mean of all targets.")
return_tracks: bool = Property(default=False,
doc="A flag for allowing the predicted track, "
"used to calculate the reward, to be "
"returned.")
measurement_noise: bool = Property(default=False,
doc="Decide whether or not to apply measurement model "
"noise to the predicted measurements for sensor "
"management.")

def __call__(self, config: Mapping[Sensor, Sequence[Action]], tracks: Set[Track],
metric_time: datetime.datetime, *args, **kwargs):
Expand Down Expand Up @@ -116,8 +124,13 @@ def __call__(self, config: Mapping[Sensor, Sequence[Action]], tracks: Set[Track]
for sensor in predicted_sensors:

# Assumes one detection per track
detections = {detection.groundtruth_path: detection
for detection in sensor.measure(predicted_tracks, noise=False)
detections = {predicted_track: detection
for detection in
sensor.measure({GroundTruthState(predicted_track.mean,
timestamp=predicted_track.timestamp,
metadata=predicted_track.metadata)},
noise=self.measurement_noise)
for predicted_track in predicted_tracks
if isinstance(detection, TrueDetection)}

for predicted_track, detection in detections.items():
Expand All @@ -143,7 +156,10 @@ def __call__(self, config: Mapping[Sensor, Sequence[Action]], tracks: Set[Track]
config_metric /= len(detections)

# Return value of configuration metric
return config_metric
if self.return_tracks:
return config_metric, predicted_tracks
else:
return config_metric


class ExpectedKLDivergence(RewardFunction):
Expand Down Expand Up @@ -183,6 +199,11 @@ class ExpectedKLDivergence(RewardFunction):
"used to calculate the reward, to be "
"returned.")

measurement_noise: bool = Property(default=False,
doc="Decide whether or not to apply measurement model "
"noise to the predicted measurements for sensor "
"management.")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.KLD = KLDivergence()
Expand Down Expand Up @@ -281,9 +302,11 @@ def _generate_detections(self, predicted_tracks, sensors, timestamp=None):
for sensor in sensors:
detections = {}
for predicted_track in predicted_tracks:
tmp_detection = sensor.measure({State(predicted_track.mean,
timestamp=predicted_track.timestamp)},
noise=True)
tmp_detection = sensor.measure(
{GroundTruthState(predicted_track.mean,
timestamp=predicted_track.timestamp,
metadata=predicted_track.metadata)},
noise=self.measurement_noise)
detections.update({predicted_track: tmp_detection})

if self.data_associator:
Expand Down Expand Up @@ -327,6 +350,8 @@ class MultiUpdateExpectedKLDivergence(ExpectedKLDivergence):
doc="Number of measurements to generate from each "
"track prediction. This should be > 1.")

measurement_noise: bool = Property(default=True)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.KLD = KLDivergence()
Expand All @@ -353,8 +378,11 @@ def _generate_detections(self, predicted_tracks, sensors, timestamp=None):
nparts=self.updates_per_track)
tmp_detections = set()
for state in measurement_sources.state_vector:
tmp_detections.update(sensor.measure({State(state, timestamp=timestamp)},
noise=True))
tmp_detections.update(
sensor.measure({GroundTruthState(state,
timestamp=timestamp,
metadata=predicted_track.metadata)},
noise=self.measurement_noise))

detections.update({predicted_track: tmp_detections})
all_detections.update({sensor: detections})
Expand Down
29 changes: 28 additions & 1 deletion stonesoup/sensormanager/tests/test_sensormanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,38 @@ def test_sensor_manager_with_platform(params):
np.diag([1.5, 0.25, 1.5, 0.25]
+ np.random.normal(0, 5e-4, 4))), # track2_state2
MCTSBestChildPolicyEnum.MAXCREWARD, # best_child_policy
), (
ParticlePredictor, # predictor_obj
ParticleUpdater, # updater_obj
None, # hypothesiser
None, # associator
UncertaintyRewardFunction, # reward_function_obj
ParticleState(state_vector=StateVectors(np.random.multivariate_normal(
mean=np.array([1, 1, 1, 1]),
cov=np.diag([1.5, 0.25, 1.5, 0.25]),
size=100).T),
weight=np.array([1/100]*100)), # track1_state1
ParticleState(state_vector=StateVectors(np.random.multivariate_normal(
mean=np.array([2, 1.5, 2, 1.5]),
cov=np.diag([3, 0.5, 3, 0.5]),
size=100).T),
weight=np.array([1/100]*100)), # track1_state2
ParticleState(state_vector=StateVectors(np.random.multivariate_normal(
mean=np.array([-1, 1, -1, 1]),
cov=np.diag([3, 0.5, 3, 0.5]),
size=100).T),
weight=np.array([1/100]*100)), # track2_state1
ParticleState(state_vector=StateVectors(np.random.multivariate_normal(
mean=np.array([2, 1.5, 2, 1.5]),
cov=np.diag([1.5, 0.25, 1.5, 0.25]),
size=100).T),
weight=np.array([1/100]*100)), # track2_state2
'max_cumulative_reward', # best_child_policy
)
],
ids=['KLDivergenceMCTSNoAssociation', 'KLDivergenceMCTSAssociation',
'KLDivergenceMCTSGaussianTest', 'KLDMCTSGaussianPolicy1', 'KLDMCTSGaussianPolicy2',
'KLDMCTSGaussianEnum']
'KLDMCTSGaussianEnum', 'UncertaintyMCTSTest']
)
def test_mcts_sensor_managers(predictor_obj, updater_obj, hypothesiser_obj, associator_obj,
reward_function_obj, track1_state1, track1_state2, track2_state1,
Expand Down

0 comments on commit 9a95387

Please sign in to comment.