Skip to content

Commit

Permalink
misc: use base printer infrastructure in WGSL printing
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Dec 11, 2024
1 parent 414bcb0 commit b02c57c
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 176 deletions.
140 changes: 59 additions & 81 deletions tests/backend/wgsl/test_wgsl_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,177 +10,155 @@


def test_gpu_global_id():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

global_id_x = gpu.GlobalIdOp(gpu.DimensionAttr(gpu.DimensionEnum.X))
printer.print(global_id_x)

printer = WGSLPrinter()
printer.print(global_id_x, file)

assert "let v0: u32 = global_invocation_id.x;" in file.getvalue()
assert "let v0: u32 = global_invocation_id.x;" in stream.getvalue()


def test_gpu_thread_id():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

thread_id_x = gpu.ThreadIdOp(gpu.DimensionAttr(gpu.DimensionEnum.X))
printer.print(thread_id_x)

printer = WGSLPrinter()
printer.print(thread_id_x, file)

assert "let v0: u32 = local_invocation_id.x;" in file.getvalue()
assert "let v0: u32 = local_invocation_id.x;" in stream.getvalue()


def test_gpu_block_id():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

block_id_x = gpu.BlockIdOp(gpu.DimensionAttr(gpu.DimensionEnum.X))
printer.print(block_id_x)

printer = WGSLPrinter()
printer.print(block_id_x, file)

assert "let v0: u32 = workgroup_id.x;" in file.getvalue()
assert "let v0: u32 = workgroup_id.x;" in stream.getvalue()


def test_gpu_grid_dim():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

num_workgroups = gpu.GridDimOp(gpu.DimensionAttr(gpu.DimensionEnum.X))
printer.print(num_workgroups)

printer = WGSLPrinter()
printer.print(num_workgroups, file)

assert "let v0: u32 = num_workgroups.x;" in file.getvalue()
assert "let v0: u32 = num_workgroups.x;" in stream.getvalue()


def test_arith_constant_unsigned():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

cst = arith.ConstantOp(IntegerAttr(42, IndexType()))

printer = WGSLPrinter()
printer.print(cst, file)

assert "let v0 : u32 = 42u;" in file.getvalue()
printer.print(cst)


def test_arith_constant_unsigned_neg():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

cst = arith.ConstantOp(IntegerAttr(-1, IndexType()))
cst.result.name_hint = "temp"
printer.print(cst)

printer = WGSLPrinter()
printer.print(cst, file)

assert "let vtemp : u32 = 4294967295u;" in file.getvalue()
assert "let vtemp : u32 = 4294967295u;" in stream.getvalue()


def test_arith_constant_signed():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

cst = arith.ConstantOp(IntegerAttr(42, IntegerType(32)))
cst.result.name_hint = "temp"
printer.print(cst)

printer = WGSLPrinter()
printer.print(cst, file)

assert "let vtemp : i32 = 42;" in file.getvalue()
assert "let vtemp : i32 = 42;" in stream.getvalue()


def test_arith_addi():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

addi = arith.AddiOp(lhs_op, rhs_op)
printer.print(addi)

printer = WGSLPrinter()
printer.print(addi, file)

assert "let v0 = v1 + v2;" in file.getvalue()
assert "let v0 = v1 + v2;" in stream.getvalue()


def test_arith_subi():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

subi = arith.SubiOp(lhs_op, rhs_op)
printer.print(subi)

printer = WGSLPrinter()
printer.print(subi, file)

assert "let v0 = v1 - v2;" in file.getvalue()
assert "let v0 = v1 - v2;" in stream.getvalue()


def test_arith_muli():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

muli = arith.MuliOp(lhs_op, rhs_op)
printer.print(muli)

printer = WGSLPrinter()
printer.print(muli, file)

assert "let v0 = v1 * v2;" in file.getvalue()
assert "let v0 = v1 * v2;" in stream.getvalue()


def test_arith_addf():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

addf = arith.AddfOp(lhs_op, rhs_op)
printer.print(addf)

printer = WGSLPrinter()
printer.print(addf, file)

assert "let v0 = v1 + v2;" in file.getvalue()
assert "let v0 = v1 + v2;" in stream.getvalue()


def test_arith_subf():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

subf = arith.SubfOp(lhs_op, rhs_op)
printer.print(subf)

printer = WGSLPrinter()
printer.print(subf, file)

assert "let v0 = v1 - v2;" in file.getvalue()
assert "let v0 = v1 - v2;" in stream.getvalue()


def test_arith_mulf():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

mulf = arith.MulfOp(lhs_op, rhs_op)
printer.print(mulf)

printer = WGSLPrinter()
printer.print(mulf, file)

assert "let v0 = v1 * v2;" in file.getvalue()
assert "let v0 = v1 * v2;" in stream.getvalue()


def test_memref_load():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

memref_type = memref.MemRefType(i32, [10, 10])

memref_val = TestSSAValue(memref_type)

load = memref.LoadOp.get(memref_val, [lhs_op.res[0], rhs_op.res[0]])
printer.print(load)

printer = WGSLPrinter()
printer.print(load, file)

assert "let v1 = v0[10u * v1 + 1u * v2];" in file.getvalue()
assert "let v1 = v0[10u * v1 + 1u * v2];" in stream.getvalue()


def test_memref_store():
file = StringIO("")
stream = StringIO()
printer = WGSLPrinter(stream=stream)

memref_type = memref.MemRefType(i32, [10, 10])

memref_val = TestSSAValue(memref_type)

load = memref.LoadOp.get(memref_val, [lhs_op.res[0], rhs_op.res[0]])

store = memref.StoreOp.get(load.res, memref_val, [lhs_op.res[0], rhs_op.res[0]])
printer.print(store)

printer = WGSLPrinter()
printer.print(store, file)

assert "v1[10u * v1 + 1u * v2] = v0;" in file.getvalue()
assert "v1[10u * v1 + 1u * v2] = v0;" in stream.getvalue()
Loading

0 comments on commit b02c57c

Please sign in to comment.