sort.rs
use std::{num::NonZeroU32, time::Duration};
use wgpu_sort::{utils::{download_buffer, guess_workgroup_size}, GPUSorter, SortBuffers};
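/// GPU handles plus the timestamp-query resources used to time the sort.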
struct SortStuff {
    device: wgpu::Device,
    queue: wgpu::Queue,
    query_set: wgpu::QuerySet,
    query_buffer: wgpu::Buffer,
}
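/// Creates a device with timestamp queries enabled and allocates a two-entry
/// query set (start/end) plus a buffer to resolve the timestamps into.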
async fn setup() -> SortStuff {
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default());
    let adapter = wgpu::util::initialize_adapter_from_env_or_default(&instance, None)
        .await
        .unwrap();
    let (device, queue) = adapter
        .request_device(
            &wgpu::DeviceDescriptor {
                // Timestamp queries are needed for GPU-side timing; the raised
                // buffer limits accommodate the largest benchmark size.
                required_features: wgpu::Features::TIMESTAMP_QUERY,
                required_limits: wgpu::Limits {
                    max_buffer_size: 1 << 30,
                    max_storage_buffer_binding_size: 1 << 30,
                    ..Default::default()
                },
                label: None,
            },
            None,
        )
        .await
        .unwrap();
    // Two timestamps: one written before and one after the sort passes.
    let capacity = 2;
    let query_set = device.create_query_set(&wgpu::QuerySetDescriptor {
        label: Some("time stamp query set"),
        ty: wgpu::QueryType::Timestamp,
        count: capacity,
    });
    let query_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("query set buffer"),
        size: capacity as u64 * std::mem::size_of::<u64>() as u64,
        usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });
    SortStuff { device, queue, query_set, query_buffer }
}
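/// Records `iters` sort passes bracketed by two GPU timestamps and returns the
/// average time per sort.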
async fn sort(
    context: &SortStuff,
    sorter: &GPUSorter,
    buffers: &SortBuffers,
    n: u32,
    iters: u32,
) -> Duration {
    let mut encoder = context
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    encoder.write_timestamp(&context.query_set, 0);
    for _ in 0..iters {
        sorter.sort(&mut encoder, &context.queue, buffers, Some(n));
    }
    encoder.write_timestamp(&context.query_set, 1);
    // Resolve both timestamps into the query buffer so they can be read back.
    encoder.resolve_query_set(&context.query_set, 0..2, &context.query_buffer, 0);
    let idx = context.queue.submit([encoder.finish()]);
    // Wait until the GPU has finished this submission before reading the timestamps.
    context
        .device
        .poll(wgpu::Maintain::WaitForSubmissionIndex(idx));
    let timestamps: Vec<u64> =
        download_buffer(&context.query_buffer, &context.device, &context.queue, ..).await;
    // Convert timestamp ticks to nanoseconds and average over the iterations.
    let diff_ticks = timestamps[1] - timestamps[0];
    let period = context.queue.get_timestamp_period();
    Duration::from_nanos((diff_ticks as f64 * period as f64 / iters as f64) as u64)
}
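/// Benchmarks the sorter at several input sizes, after probing the device for a
/// usable subgroup size, and prints the average time per sort.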
#[pollster::main]
async fn main() {
    let context = setup().await;
    let subgroup_size = guess_workgroup_size(&context.device, &context.queue)
        .await
        .expect("could not find a valid subgroup size");
    let sorter = GPUSorter::new(&context.device, subgroup_size);
    for n in [10_000, 100_000, 1_000_000, 8_000_000, 20_000_000] {
        let buffers = sorter.create_sort_buffers(&context.device, NonZeroU32::new(n).unwrap());
        let d = sort(&context, &sorter, &buffers, n, 10_000).await;
        println!("{n}: {d:?}");
    }
}