Skip to content

Commit

Permalink
Check more validation errors
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Dec 30, 2024
1 parent 62c7184 commit 7ec9671
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
36 changes: 29 additions & 7 deletions pydawn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def create_shader_module(device, source):

return shader_module

def create_bind_group_layout(device, entries):
def create_bind_group_layout(device, entries, validate=True):
webgpuEntries = []

for entry in entries:
Expand All @@ -191,17 +191,32 @@ def create_bind_group_layout(device, entries):
desc.entryCount = len(webgpuEntries)
desc.entries = ctypes.cast(entries_array, ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry))

return webgpu.wgpuDeviceCreateBindGroupLayout(device, desc)
webgpu.wgpuDevicePushErrorScope(device, webgpu.WGPUErrorFilter_Validation)
ret = webgpu.wgpuDeviceCreateBindGroupLayout(device, desc)
layout_error = pop_error(device)

if layout_error and validate:
raise RuntimeError(f"Error creating bind group layout: {layout_error}")

return ret

def create_pipeline_layout(device, bind_group_layouts):
def create_pipeline_layout(device, bind_group_layouts, validate=True):
pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor()
pipeline_layout_desc.bindGroupLayoutCount = len(bind_group_layouts)
bind_group_array_type = webgpu.WGPUBindGroupLayout * len(bind_group_layouts)
bind_groups_ctype = bind_group_array_type(*bind_group_layouts)
pipeline_layout_desc.bindGroupLayouts = bind_groups_ctype
return webgpu.wgpuDeviceCreatePipelineLayout(device, pipeline_layout_desc)

def create_bind_group(device, layout, entries):
webgpu.wgpuDevicePushErrorScope(device, webgpu.WGPUErrorFilter_Validation)
ret = webgpu.wgpuDeviceCreatePipelineLayout(device, pipeline_layout_desc)
layout_error = pop_error(device)

if layout_error and validate:
raise RuntimeError(f"Error creating pipeline layout: {layout_error}")

return ret

def create_bind_group(device, layout, entries, validate=True):
bind_group_desc = webgpu.WGPUBindGroupDescriptor()
bind_group_desc.layout = layout
bind_group_desc.entryCount = len(entries)
Expand All @@ -218,8 +233,15 @@ def create_bind_group(device, layout, entries):
entries_array_type = webgpu.WGPUBindGroupEntry * len(webgpu_entries)
entries_array = entries_array_type(*webgpu_entries)
bind_group_desc.entries = entries_array

return webgpu.wgpuDeviceCreateBindGroup(device, bind_group_desc)

webgpu.wgpuDevicePushErrorScope(device, webgpu.WGPUErrorFilter_Validation)
ret = webgpu.wgpuDeviceCreateBindGroup(device, bind_group_desc)
bind_group_error = pop_error(device)

if bind_group_error and validate:
raise RuntimeError(f"Error creating bind group: {bind_group_error}")

return ret

def create_compute_pipeline(device, layout, compute):
compute_desc = webgpu.WGPUComputePipelineDescriptor()
Expand Down
25 changes: 25 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,30 @@ def test_map_buffer_error(self):

self.assertIn("Failed to map buffer", str(ctx.exception))

def test_create_bind_group_layout_error(self):
with self.assertRaises(RuntimeError) as ctx:
num_bufs = 11
bind_group_layouts = [{
"binding": i,
"visibility": webgpu.WGPUShaderStage_Compute,
"buffer": {"type": webgpu.WGPUBufferBindingType_Storage},
} for i in range(num_bufs)]
utils.create_bind_group_layout(self.device, bind_group_layouts)

self.assertIn(f"The number of storage buffers ({num_bufs}) in the Compute stage exceeds the maximum per-stage limit", str(ctx.exception))

def test_create_bind_group_layout_error(self):
with self.assertRaises(RuntimeError) as ctx:
bind_group_layouts = [{
"binding": i,
"visibility": webgpu.WGPUShaderStage_Fragment,
"buffer": {"type": webgpu.WGPUBufferBindingType_Storage},
} for i in range(11)]
# Disable validation when creating bind group layout to catch the error later
bind_group_layout = utils.create_bind_group_layout(self.device, bind_group_layouts, validate=False)
utils.create_pipeline_layout(self.device, [bind_group_layout])

self.assertIn(f"Error creating pipeline layout: [Invalid BindGroupLayout (unlabeled)] is invalid.", str(ctx.exception))

if __name__ == "__main__":
unittest.main()

0 comments on commit 7ec9671

Please sign in to comment.