diff --git a/pydawn/utils.py b/pydawn/utils.py index 95204b1..7a18a28 100644 --- a/pydawn/utils.py +++ b/pydawn/utils.py @@ -172,7 +172,7 @@ def create_shader_module(device, source): return shader_module -def create_bind_group_layout(device, entries): +def create_bind_group_layout(device, entries, validate=True): webgpuEntries = [] for entry in entries: @@ -191,17 +191,32 @@ def create_bind_group_layout(device, entries): desc.entryCount = len(webgpuEntries) desc.entries = ctypes.cast(entries_array, ctypes.POINTER(webgpu.WGPUBindGroupLayoutEntry)) - return webgpu.wgpuDeviceCreateBindGroupLayout(device, desc) + webgpu.wgpuDevicePushErrorScope(device, webgpu.WGPUErrorFilter_Validation) + ret = webgpu.wgpuDeviceCreateBindGroupLayout(device, desc) + layout_error = pop_error(device) + + if layout_error and validate: + raise RuntimeError(f"Error creating bind group layout: {layout_error}") + + return ret -def create_pipeline_layout(device, bind_group_layouts): +def create_pipeline_layout(device, bind_group_layouts, validate=True): pipeline_layout_desc = webgpu.WGPUPipelineLayoutDescriptor() pipeline_layout_desc.bindGroupLayoutCount = len(bind_group_layouts) bind_group_array_type = webgpu.WGPUBindGroupLayout * len(bind_group_layouts) bind_groups_ctype = bind_group_array_type(*bind_group_layouts) pipeline_layout_desc.bindGroupLayouts = bind_groups_ctype - return webgpu.wgpuDeviceCreatePipelineLayout(device, pipeline_layout_desc) -def create_bind_group(device, layout, entries): + webgpu.wgpuDevicePushErrorScope(device, webgpu.WGPUErrorFilter_Validation) + ret = webgpu.wgpuDeviceCreatePipelineLayout(device, pipeline_layout_desc) + layout_error = pop_error(device) + + if layout_error and validate: + raise RuntimeError(f"Error creating pipeline layout: {layout_error}") + + return ret + +def create_bind_group(device, layout, entries, validate=True): bind_group_desc = webgpu.WGPUBindGroupDescriptor() bind_group_desc.layout = layout bind_group_desc.entryCount = len(entries) @@ -218,8 +233,15 @@ def create_bind_group(device, layout, entries): entries_array_type = webgpu.WGPUBindGroupEntry * len(webgpu_entries) entries_array = entries_array_type(*webgpu_entries) bind_group_desc.entries = entries_array - - return webgpu.wgpuDeviceCreateBindGroup(device, bind_group_desc) + + webgpu.wgpuDevicePushErrorScope(device, webgpu.WGPUErrorFilter_Validation) + ret = webgpu.wgpuDeviceCreateBindGroup(device, bind_group_desc) + bind_group_error = pop_error(device) + + if bind_group_error and validate: + raise RuntimeError(f"Error creating bind group: {bind_group_error}") + + return ret def create_compute_pipeline(device, layout, compute): compute_desc = webgpu.WGPUComputePipelineDescriptor() diff --git a/test/test_utils.py b/test/test_utils.py index fe018e2..be31497 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -77,5 +77,30 @@ def test_map_buffer_error(self): self.assertIn("Failed to map buffer", str(ctx.exception)) + def test_create_bind_group_layout_error(self): + with self.assertRaises(RuntimeError) as ctx: + num_bufs = 11 + bind_group_layouts = [{ + "binding": i, + "visibility": webgpu.WGPUShaderStage_Compute, + "buffer": {"type": webgpu.WGPUBufferBindingType_Storage}, + } for i in range(num_bufs)] + utils.create_bind_group_layout(self.device, bind_group_layouts) + + self.assertIn(f"The number of storage buffers ({num_bufs}) in the Compute stage exceeds the maximum per-stage limit", str(ctx.exception)) + + def test_create_bind_group_layout_error(self): + with self.assertRaises(RuntimeError) as ctx: + bind_group_layouts = [{ + "binding": i, + "visibility": webgpu.WGPUShaderStage_Fragment, + "buffer": {"type": webgpu.WGPUBufferBindingType_Storage}, + } for i in range(11)] + # Disable validation when creating bind group layout to catch the error later + bind_group_layout = utils.create_bind_group_layout(self.device, bind_group_layouts, validate=False) + utils.create_pipeline_layout(self.device, [bind_group_layout]) + + self.assertIn(f"Error creating pipeline layout: [Invalid BindGroupLayout (unlabeled)] is invalid.", str(ctx.exception)) + if __name__ == "__main__": unittest.main()