Skip to content

Commit

Permalink
FabricClient: Make APIs "Send" (#69)
Browse files Browse the repository at this point in the history
Make FabricClient's use of SF raw structs scoped to avoid having them
across await points, and this makes FabricClient APIs "Send" and can be
used anywhere in tokio runtime.
  • Loading branch information
youyuanwu authored Sep 4, 2024
1 parent ff6a9d9 commit 9d611b8
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 43 deletions.
36 changes: 20 additions & 16 deletions crates/libs/core/src/client/query_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::{
},
};

#[derive(Debug, Clone)]
pub struct QueryClient {
com: IFabricQueryClient10,
}
Expand Down Expand Up @@ -99,8 +100,9 @@ impl QueryClient {
timeout: Duration,
cancellation_token: Option<crate::sync::CancellationToken>,
) -> windows_core::Result<NodeList> {
let fu;
{
// Note that the SF raw structs are scoped to avoid having them across await points.
// This makes api Send. All FabricClient api should follow this pattern.
let com = {
let ex3 = FABRIC_NODE_QUERY_DESCRIPTION_EX3 {
MaxResults: desc.paged_query.max_results.unwrap_or(0),
Reserved: std::ptr::null_mut(),
Expand All @@ -120,14 +122,14 @@ impl QueryClient {
NodeNameFilter: get_pcwstr_from_opt(&desc.node_name_filter),
Reserved: std::ptr::addr_of!(ex1) as *mut c_void,
};
fu = self.get_node_list_internal(
self.get_node_list_internal(
&arg,
timeout.as_millis().try_into().unwrap(),
cancellation_token,
);
)
}
let res = fu.await??;
Ok(NodeList::from_com(res))
.await??;
Ok(NodeList::from_com(com))
}

pub async fn get_partition_list(
Expand All @@ -136,11 +138,12 @@ impl QueryClient {
timeout: Duration,
cancellation_token: Option<CancellationToken>,
) -> crate::Result<ServicePartitionList> {
let raw: FABRIC_SERVICE_PARTITION_QUERY_DESCRIPTION = desc.into();
let mili = timeout.as_millis() as u32;
let com = self
.get_partition_list_internal(&raw, mili, cancellation_token)
.await??;
let com = {
let raw: FABRIC_SERVICE_PARTITION_QUERY_DESCRIPTION = desc.into();
let mili = timeout.as_millis() as u32;
self.get_partition_list_internal(&raw, mili, cancellation_token)
}
.await??;
Ok(ServicePartitionList::new(com))
}

Expand All @@ -150,11 +153,12 @@ impl QueryClient {
timeout: Duration,
cancellation_token: Option<CancellationToken>,
) -> crate::Result<ServiceReplicaList> {
let raw: FABRIC_SERVICE_REPLICA_QUERY_DESCRIPTION = desc.into();
let mili = timeout.as_millis() as u32;
let com = self
.get_replica_list_internal(&raw, mili, cancellation_token)
.await??;
let com = {
let raw: FABRIC_SERVICE_REPLICA_QUERY_DESCRIPTION = desc.into();
let mili = timeout.as_millis() as u32;
self.get_replica_list_internal(&raw, mili, cancellation_token)
}
.await??;
Ok(ServiceReplicaList::new(com))
}
}
59 changes: 33 additions & 26 deletions crates/libs/core/src/client/svc_mgmt_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::{
};

// Service Management Client
#[derive(Debug, Clone)]
pub struct ServiceManagementClient {
com: IFabricServiceManagementClient6,
}
Expand Down Expand Up @@ -152,22 +153,23 @@ impl ServiceManagementClient {
timeout: Duration,
cancellation_token: Option<CancellationToken>,
) -> windows_core::Result<ResolvedServicePartition> {
let uri = FABRIC_URI(name.as_ptr() as *mut u16);
// supply prev as null if not present
let prev_opt = prev.map(|x| &x.com);

let part_key_opt = key_type.get_raw_opt();

let fu = self.resolve_service_partition_internal(
uri,
key_type.into(),
part_key_opt,
prev_opt,
timeout.as_millis().try_into().unwrap(),
cancellation_token,
);

let com = fu.await??;
let com = {
let uri = FABRIC_URI(name.as_ptr() as *mut u16);
// supply prev as null if not present
let prev_opt = prev.map(|x| &x.com);

let part_key_opt = key_type.get_raw_opt();

self.resolve_service_partition_internal(
uri,
key_type.into(),
part_key_opt,
prev_opt,
timeout.as_millis().try_into().unwrap(),
cancellation_token,
)
}
.await??;
let res = ResolvedServicePartition::from_com(com);
Ok(res)
}
Expand All @@ -182,9 +184,11 @@ impl ServiceManagementClient {
timeout: Duration,
cancellation_token: Option<CancellationToken>,
) -> crate::Result<()> {
let raw: FABRIC_RESTART_REPLICA_DESCRIPTION = desc.into();
self.restart_replica_internal(&raw, timeout.as_millis() as u32, cancellation_token)
.await?
{
let raw: FABRIC_RESTART_REPLICA_DESCRIPTION = desc.into();
self.restart_replica_internal(&raw, timeout.as_millis() as u32, cancellation_token)
}
.await?
}

/// This API gives a running replica the chance to cleanup its state and be gracefully shutdown.
Expand All @@ -198,9 +202,11 @@ impl ServiceManagementClient {
timeout: Duration,
cancellation_token: Option<CancellationToken>,
) -> crate::Result<()> {
let raw: FABRIC_REMOVE_REPLICA_DESCRIPTION = desc.into();
self.remove_replica_internal(&raw, timeout.as_millis() as u32, cancellation_token)
.await?
{
let raw: FABRIC_REMOVE_REPLICA_DESCRIPTION = desc.into();
self.remove_replica_internal(&raw, timeout.as_millis() as u32, cancellation_token)
}
.await?
}

/// Remarks:
Expand All @@ -219,14 +225,15 @@ impl ServiceManagementClient {
timeout: Duration,
cancellation_token: Option<CancellationToken>,
) -> crate::Result<FilterIdHandle> {
let raw: FABRIC_SERVICE_NOTIFICATION_FILTER_DESCRIPTION = desc.into();
let id = self
.register_service_notification_filter_internal(
let id = {
let raw: FABRIC_SERVICE_NOTIFICATION_FILTER_DESCRIPTION = desc.into();
self.register_service_notification_filter_internal(
&raw,
timeout.as_millis() as u32,
cancellation_token,
)
.await??;
}
.await??;
Ok(FilterIdHandle { id })
}

Expand Down
9 changes: 8 additions & 1 deletion crates/libs/core/src/client/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ async fn test_fabric_client() {
},
..Default::default()
};
let list = qc.get_node_list(&desc, timeout, None).await.unwrap();
let qc_cp = qc.clone();
let list = tokio::spawn(async move {
// make sure api is Send.
qc_cp.get_node_list(&desc, timeout, None).await
})
.await
.unwrap()
.unwrap();
paging_status = list.get_paging_status();
let v = list.iter().collect::<Vec<_>>();
assert_ne!(v.len(), 0);
Expand Down

0 comments on commit 9d611b8

Please sign in to comment.