diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 4b35ab3af..df392b44f 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -369,13 +369,14 @@ impl Processor { /// [`QueryStatusError::DifferentStatus`] and retrieve it's internal state. Returns [`None`] /// if not possible. #[cfg(feature = "in-memory-infra")] - fn downcast_state_error(box_error: crate::error::BoxError) -> Option { + fn downcast_state_error(box_error: &crate::error::BoxError) -> Option { use crate::helpers::ApiError; - let api_error = box_error.downcast::().ok()?; - if let ApiError::QueryStatus(QueryStatusError::DifferentStatus { my_status, .. }) = - *api_error + let api_error = box_error.downcast_ref::(); + if let Some(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + my_status, .. + })) = api_error { - return Some(my_status); + return Some(*my_status); } None } @@ -386,7 +387,7 @@ impl Processor { /// of relying on errors. #[cfg(feature = "in-memory-infra")] fn get_state_from_error( - error: crate::helpers::InMemoryTransportError, + error: &crate::helpers::InMemoryTransportError, ) -> Option { if let crate::helpers::InMemoryTransportError::Rejected { inner, .. } = error { return Self::downcast_state_error(inner); @@ -399,8 +400,8 @@ impl Processor { /// TODO: Ideally broadcast should return a value, that we could use to parse the state instead /// of relying on errors. #[cfg(feature = "real-world-infra")] - fn get_state_from_error(shard_error: crate::net::ShardError) -> Option { - if let crate::net::Error::ShardQueryStatusMismatch { error, .. } = shard_error.source { + fn get_state_from_error(shard_error: &crate::net::ShardError) -> Option { + if let crate::net::Error::ShardQueryStatusMismatch { error, .. } = &shard_error.source { return Some(error.actual); } None @@ -431,17 +432,14 @@ impl Processor { let shard_responses = shard_transport.broadcast(shard_query_status_req).await; if let Err(e) = shard_responses { - // The following silently ignores the cases where the query isn't found. - // TODO: this code is a ticking bomb - it ignores all errors, not just when - // query is not found. If there is no handler, handler responded with an error, etc. - // Moreover, any error may result in client mistakenly assuming that the status - // is completed. - let states: Vec<_> = e - .failures - .into_iter() - .filter_map(|(_si, e)| Self::get_state_from_error(e)) - .collect(); - status = states.into_iter().fold(status, min_status); + for (shard, failure) in &e.failures { + if let Some(other) = Self::get_state_from_error(failure) { + status = min_status(status, other); + } else { + tracing::error!("failed to get status from shard {shard}: {failure:?}"); + return Err(e.into()); + } + } } Ok(status) @@ -1205,9 +1203,12 @@ mod tests { /// * From the standpoint of leader shard in Helper 1 /// * On query_status /// - /// If one of my shards hasn't received the query yet (NoSuchQuery) the leader shouldn't - /// return an error but instead with the min state. + /// If one of my shards hasn't received the query yet (NoSuchQuery) the leader should + /// return an error despite other shards returning their status #[tokio::test] + #[should_panic( + expected = "(ShardIndex(3), Rejected { dest: ShardIndex(3), inner: QueryStatus(NoSuchQuery(QueryId)) })" + )] async fn status_query_doesnt_exist() { fn shard_handle(si: ShardIndex) -> Arc> { create_handler(move |_| async move { @@ -1215,6 +1216,12 @@ mod tests { Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( QueryId, ))) + } else if si == ShardIndex::from(2) { + Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { + query_id: QueryId, + my_status: QueryStatus::Running, + other_status: QueryStatus::Preparing, + })) } else { Ok(HelperResponse::ok()) } @@ -1237,16 +1244,10 @@ mod tests { req, ) .unwrap(); - let r = t - .processor + t.processor .query_status(t.shard_transport.clone_ref(), QueryId) - .await; - if let Err(e) = r { - panic!("Unexpected error {e}"); - } - if let Ok(st) = r { - assert_eq!(QueryStatus::AwaitingInputs, st); - } + .await + .unwrap(); } /// Context: