Skip to content

Commit

Permalink
Cleanup cuda algorithm coersion
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Jan 11, 2024
1 parent 65ed1c8 commit b6dd445
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1037,7 +1037,7 @@ fn debug_display_sampler() {

assert_eq!(
&alloc::format!("{sampler:?}"),
"DynamicAliasMethodIndexedSampler { exponents: [], total_weight: 0.0 }"
"DynamicAliasMethodIndexedSampler { exponents: [], total_weight: 0.0, .. }"
);

for i in (1..=6_u8).rev() {
Expand All @@ -1046,7 +1046,7 @@ fn debug_display_sampler() {

assert_eq!(
&alloc::format!("{sampler:?}"),
"DynamicAliasMethodIndexedSampler { exponents: [2, 1, 0], total_weight: 21.0 }"
"DynamicAliasMethodIndexedSampler { exponents: [2, 1, 0], total_weight: 21.0, .. }"
);

let mut sampler_clone = unsafe { sampler.backup_unchecked() };
Expand All @@ -1062,11 +1062,11 @@ fn debug_display_sampler() {

assert_eq!(
&alloc::format!("{sampler:?}"),
"DynamicAliasMethodIndexedSampler { exponents: [2, 1, 0], total_weight: 18.0 }"
"DynamicAliasMethodIndexedSampler { exponents: [2, 1, 0], total_weight: 18.0, .. }"
);
assert_eq!(
&alloc::format!("{sampler_clone:?}"),
"DynamicAliasMethodIndexedSampler { exponents: [2, 1], total_weight: 20.0 }"
"DynamicAliasMethodIndexedSampler { exponents: [2, 1], total_weight: 20.0, .. }"
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ fn debug_display_sampler() {

assert_eq!(
&alloc::format!("{sampler:?}"),
"DynamicAliasMethodStackSampler { exponents: [], total_weight: 0.0 }"
"DynamicAliasMethodStackSampler { exponents: [], total_weight: 0.0, .. }"
);

for i in (1..=6_u8).rev() {
Expand All @@ -544,7 +544,7 @@ fn debug_display_sampler() {

assert_eq!(
&alloc::format!("{sampler:?}"),
"DynamicAliasMethodStackSampler { exponents: [2, 1, 0], total_weight: 21.0 }"
"DynamicAliasMethodStackSampler { exponents: [2, 1, 0], total_weight: 21.0, .. }"
);

let mut sampler_clone = unsafe { sampler.backup_unchecked() };
Expand All @@ -560,11 +560,11 @@ fn debug_display_sampler() {

assert_eq!(
&alloc::format!("{sampler:?}"),
"DynamicAliasMethodStackSampler { exponents: [2, 1, 0], total_weight: 18.0 }"
"DynamicAliasMethodStackSampler { exponents: [2, 1, 0], total_weight: 18.0, .. }"
);
assert_eq!(
&alloc::format!("{sampler_clone:?}"),
"DynamicAliasMethodStackSampler { exponents: [2, 1], total_weight: 20.0 }"
"DynamicAliasMethodStackSampler { exponents: [2, 1], total_weight: 20.0, .. }"
);
}

Expand Down
1 change: 0 additions & 1 deletion rustcoalescence/algorithms/cuda/cpu-kernel/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![deny(clippy::pedantic)]
#![feature(c_str_literals)]
#![feature(min_specialization)]
#![allow(long_running_const_eval)]
#![recursion_limit = "1024"]

Expand Down
50 changes: 49 additions & 1 deletion rustcoalescence/algorithms/cuda/cpu-kernel/src/link.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,52 @@
use crate::SimulationKernelPtx;
use necsim_core::{
cogs::{
CoalescenceSampler, DispersalSampler, EmigrationExit, Habitat, ImmigrationEntry,
LineageStore, MathsCore, PrimeableRng, SpeciationProbability, TurnoverRate,
},
reporter::boolean::Boolean,
};

use necsim_impls_no_std::cogs::{
active_lineage_sampler::singular::SingularActiveLineageSampler,
event_sampler::tracking::MinSpeciationTrackingEventSampler,
};

use rust_cuda::lend::RustToCuda;

#[allow(clippy::type_complexity)]
pub struct SimulationKernelPtx<
M: MathsCore + Sync,
H: Habitat<M> + RustToCuda + Sync,
G: PrimeableRng<M> + RustToCuda + Sync,
S: LineageStore<M, H> + RustToCuda + Sync,
X: EmigrationExit<M, H, G, S> + RustToCuda + Sync,
D: DispersalSampler<M, H, G> + RustToCuda + Sync,
C: CoalescenceSampler<M, H, S> + RustToCuda + Sync,
T: TurnoverRate<M, H> + RustToCuda + Sync,
N: SpeciationProbability<M, H> + RustToCuda + Sync,
E: MinSpeciationTrackingEventSampler<M, H, G, S, X, D, C, T, N> + RustToCuda + Sync,
I: ImmigrationEntry<M> + RustToCuda + Sync,
A: SingularActiveLineageSampler<M, H, G, S, X, D, C, T, N, E, I> + RustToCuda + Sync,
ReportSpeciation: Boolean,
ReportDispersal: Boolean,
>(
std::marker::PhantomData<(
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
ReportSpeciation,
ReportDispersal,
)>,
);

macro_rules! link_kernel {
($habitat:ty, $dispersal:ty, $turnover:ty, $speciation:ty) => {
Expand Down
169 changes: 148 additions & 21 deletions rustcoalescence/algorithms/cuda/cpu-kernel/src/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use necsim_core::{
CoalescenceSampler, DispersalSampler, EmigrationExit, Habitat, ImmigrationEntry,
LineageStore, MathsCore, PrimeableRng, SpeciationProbability, TurnoverRate,
},
reporter::boolean::Boolean, // reporter::boolean::{Boolean, False, True},
reporter::boolean::{Boolean, False, True},
};
use necsim_impls_no_std::cogs::{
active_lineage_sampler::singular::SingularActiveLineageSampler,
Expand All @@ -21,11 +21,7 @@ use crate::SimulationKernelPtx;
// If `Kernel` is implemented for `ReportSpeciation` x `ReportDispersal`, i.e.
// for {`False`, `True`} x {`False`, `True`} then it is implemented for all
// `Boolean`s. However, Rust does not recognise that `Boolean` is closed over
// {`False`, `True`}. These default impls provide the necessary coersion.

extern "C" {
fn unreachable_cuda_simulation_linking_reporter() -> !;
}
// {`False`, `True`}. This explicit impl provides the necessary coersion.

#[allow(clippy::trait_duplication_in_bounds)]
unsafe impl<
Expand All @@ -46,23 +42,154 @@ unsafe impl<
>
CompiledKernelPtx<
simulate<M, H, G, S, X, D, C, T, N, E, I, A, ReportSpeciation, ReportDispersal>,
>
for SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, ReportSpeciation, ReportDispersal>
//where
// SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, False, False>:
// CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, False, False>>,
// SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, False, True>:
// CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, False, True>>,
// SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, True, False>:
// CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, True, False>>,
// SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, True, True>:
// CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, True, True>>,
> for SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, ReportSpeciation, ReportDispersal>
where
crate::link::SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, False, False>:
CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, False, False>>,
crate::link::SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, False, True>:
CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, False, True>>,
crate::link::SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, True, False>:
CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, True, False>>,
crate::link::SimulationKernelPtx<M, H, G, S, X, D, C, T, N, E, I, A, True, True>:
CompiledKernelPtx<simulate<M, H, G, S, X, D, C, T, N, E, I, A, True, True>>,
{
default fn get_ptx() -> &'static CStr {
unsafe { unreachable_cuda_simulation_linking_reporter() }
#[inline]
fn get_ptx() -> &'static CStr {
match (ReportSpeciation::VALUE, ReportDispersal::VALUE) {
(false, false) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
False,
False,
>::get_ptx(),
(false, true) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
False,
True,
>::get_ptx(),
(true, false) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
True,
False,
>::get_ptx(),
(true, true) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
True,
True,
>::get_ptx(),
}
}

default fn get_entry_point() -> &'static CStr {
unsafe { unreachable_cuda_simulation_linking_reporter() }
#[inline]
fn get_entry_point() -> &'static CStr {
match (ReportSpeciation::VALUE, ReportDispersal::VALUE) {
(false, false) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
False,
False,
>::get_entry_point(),
(false, true) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
False,
True,
>::get_entry_point(),
(true, false) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
True,
False,
>::get_entry_point(),
(true, true) => crate::link::SimulationKernelPtx::<
M,
H,
G,
S,
X,
D,
C,
T,
N,
E,
I,
A,
True,
True,
>::get_entry_point(),
}
}
}
Loading

0 comments on commit b6dd445

Please sign in to comment.