Skip to content

Commit

Permalink
editoast: fix projection endpoint
Browse files Browse the repository at this point in the history
Co-authored-by: Youness CHRIFI ALAOUI <[email protected]>
  • Loading branch information
flomonster and younesschrifi committed Dec 5, 2024
1 parent 52c374b commit 35b4cba
Showing 1 changed file with 29 additions and 28 deletions.
57 changes: 29 additions & 28 deletions editoast/src/views/train_schedule/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ async fn project_path(
let mut trains_hash_values = vec![];
let mut trains_details = vec![];

for (sim, pathfinding_result) in simulations {
for (train, (sim, pathfinding_result)) in izip!(&trains, simulations) {
let track_ranges = match pathfinding_result {
PathfindingResult::Success(PathfindingResultSuccess {
track_section_ranges,
Expand All @@ -221,6 +221,7 @@ async fn project_path(
} = report_train;

let train_details = TrainSimulationDetails {
train_id: train.id,
positions,
times,
signal_critical_positions,
Expand All @@ -242,17 +243,13 @@ async fn project_path(
let cached_projections: Vec<Option<CachedProjectPathTrainResult>> =
valkey_conn.json_get_bulk(&trains_hash_values).await?;

let mut hit_cache: HashMap<i64, CachedProjectPathTrainResult> = HashMap::new();
let mut miss_cache = HashMap::new();
for (train_details, projection, train_id) in izip!(
trains_details,
cached_projections,
trains.iter().map(|t| t.id)
) {
let mut hit_cache = vec![];
let mut miss_cache = vec![];
for (train_details, projection) in izip!(&trains_details, cached_projections) {
if let Some(cached) = projection {
hit_cache.insert(train_id, cached);
hit_cache.push((cached, train_details.train_id));
} else {
miss_cache.insert(train_id, train_details.clone());
miss_cache.push(train_details.clone());
}
}

Expand All @@ -277,40 +274,40 @@ async fn project_path(
let signal_updates = signal_updates?;

// 3. Store the projection in the cache (using pipeline)
let trains_hash_values: HashMap<_, _> = trains
let trains_hash_values: HashMap<_, _> = trains_details
.iter()
.map(|t| t.id)
.map(|t| t.train_id)
.zip(trains_hash_values)
.collect();
let mut new_items = vec![];
for id in miss_cache.keys() {
let hash = &trains_hash_values[id];
for id in miss_cache.iter().map(|t| t.train_id) {
let hash = &trains_hash_values[&id];
let cached_value = CachedProjectPathTrainResult {
space_time_curves: space_time_curves
.get(id)
.get(&id)
.expect("Space time curves not available for train")
.clone(),
signal_updates: signal_updates
.get(id)
.get(&id)
.expect("Signal update not available for train")
.clone(),
};
hit_cache.insert(*id, cached_value.clone());
hit_cache.push((cached_value.clone(), id));
new_items.push((hash, cached_value));
}
valkey_conn.json_set_bulk(&new_items).await?;

let train_map: HashMap<i64, TrainSchedule> = trains.into_iter().map(|ts| (ts.id, ts)).collect();

// 4.1 Fetch rolling stock length
let mut project_path_result = HashMap::new();
let rolling_stock_length: HashMap<_, _> = rolling_stocks
.into_iter()
.map(|rs| (rs.name, rs.length))
.collect();

// 4.2 Build the projection response
for (id, cached) in hit_cache {
let mut project_path_result = HashMap::new();
for (cached, id) in hit_cache {
let train = train_map.get(&id).expect("Train not found");
let length = rolling_stock_length
.get(&train.rolling_stock_name)
Expand All @@ -332,6 +329,7 @@ async fn project_path(
/// Input for the projection of a train schedule on a path
#[derive(Debug, Clone, Hash)]
struct TrainSimulationDetails {
train_id: i64,
positions: Vec<u64>,
times: Vec<u64>,
train_path: Vec<TrackRange>,
Expand All @@ -346,7 +344,7 @@ async fn compute_batch_signal_updates<'a>(
path_track_ranges: &'a Vec<TrackRange>,
path_routes: &'a Vec<Identifier>,
path_blocks: &'a Vec<Identifier>,
trains_details: &'a HashMap<i64, TrainSimulationDetails>,
trains_details: &'a [TrainSimulationDetails],
) -> Result<HashMap<i64, Vec<SignalUpdate>>> {
if trains_details.is_empty() {
return Ok(HashMap::new());
Expand All @@ -359,13 +357,13 @@ async fn compute_batch_signal_updates<'a>(
blocks: path_blocks,
train_simulations: trains_details
.iter()
.map(|(id, details)| {
.map(|detail| {
(
*id,
detail.train_id,
TrainSimulation {
signal_critical_positions: &details.signal_critical_positions,
zone_updates: &details.zone_updates,
simulation_end_time: details.times[details.times.len() - 1],
signal_critical_positions: &detail.signal_critical_positions,
zone_updates: &detail.zone_updates,
simulation_end_time: detail.times[detail.times.len() - 1],
},
)
})
Expand All @@ -377,14 +375,14 @@ async fn compute_batch_signal_updates<'a>(

/// Compute space time curves of a list of train schedules
async fn compute_batch_space_time_curves<'a>(
trains_details: &HashMap<i64, TrainSimulationDetails>,
trains_details: &Vec<TrainSimulationDetails>,
path_projection: &PathProjection<'a>,
) -> HashMap<i64, Vec<SpaceTimeCurve>> {
let mut space_time_curves = HashMap::new();

for (train_id, train_detail) in trains_details {
for train_detail in trains_details {
space_time_curves.insert(
*train_id,
train_detail.train_id,
compute_space_time_curves(train_detail, path_projection),
);
}
Expand Down Expand Up @@ -584,6 +582,7 @@ mod tests {
];

let project_path_input = TrainSimulationDetails {
train_id: 0,
positions,
times,
train_path,
Expand Down Expand Up @@ -618,6 +617,7 @@ mod tests {
];

let project_path_input = TrainSimulationDetails {
train_id: 0,
positions: positions.clone(),
times: times.clone(),
train_path,
Expand Down Expand Up @@ -655,6 +655,7 @@ mod tests {
let path_projection = PathProjection::new(&path);

let project_path_input = TrainSimulationDetails {
train_id: 0,
positions,
times,
train_path,
Expand Down

0 comments on commit 35b4cba

Please sign in to comment.