Skip to content

Commit

Permalink
Update Triton to c75c6b034756629b891e7b2df406f634552331d5 (#223)
Browse files Browse the repository at this point in the history
Trying to fix #178

This PR includes cosmetic changes due to LLVM API change, fix for link
error, lit test update and adding unsupported tests in conftest.py

---------

Co-authored-by: Zhaoshi Zheng <[email protected]>
  • Loading branch information
zhaoshiz and Zhaoshi Zheng authored Feb 6, 2025
1 parent 36c6551 commit 6f718b7
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 47 deletions.
2 changes: 1 addition & 1 deletion backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def parse_options(self, opts) -> Any:
args.update({k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts})
return CPUOptions(**args)

def get_codegen_implementation(self):
def get_codegen_implementation(self, options):
codegen_fns = {"min_dot_size": lambda lhsType, rhsType: (1, 1, 1)}
return codegen_fns

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ struct MakeTensorPtrConverter
SmallVector<Value> newOffsets;
for (auto [offset, stride] :
llvm::zip(pointerState.offsets, pointerState.strides)) {
auto mulOp = rewriter.create<arith::MulIOp>(loc, offset.get<Value>(),
stride.get<Value>());
auto mulOp = rewriter.create<arith::MulIOp>(loc, cast<Value>(offset),
cast<Value>(stride));
newOffsets.push_back(mulOp.getResult());
}

Expand Down Expand Up @@ -435,7 +435,7 @@ struct LoadConverter : public OpConversionPattern<triton::LoadOp> {
Value dimi = dyn_cast<Value>(mstate.dims[i]);
if (!dimi) {
dimi = rewriter.create<arith::ConstantOp>(
loc, cast<IntegerAttr>(mstate.dims[i].get<Attribute>()));
loc, cast<IntegerAttr>(cast<Attribute>(mstate.dims[i])));
}

auto cmpOp = rewriter.create<arith::CmpIOp>(
Expand Down Expand Up @@ -1236,9 +1236,10 @@ struct ReduceConverter : public OpConversionPattern<triton::ReduceOp> {
}

bool requiresF32Conversion(const Type elemType, Operation *redOp) const {
unsigned width =
cast<FloatType>(Float32Type::get(elemType.getContext())).getWidth();
return isa<FloatType>(elemType) &&
elemType.getIntOrFloatBitWidth() <
Float32Type::get(elemType.getContext()).getWidth() &&
elemType.getIntOrFloatBitWidth() < width &&
isa<arith::AddFOp>(redOp);
}

Expand Down
6 changes: 3 additions & 3 deletions lib/Analysis/OpFoldResultUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
namespace mlir {

std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
if (ofr.is<Attribute>() && isa<IntegerAttr>(ofr.get<Attribute>()))
return dyn_cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
if (isa<Attribute>(ofr) && isa<IntegerAttr>(cast<Attribute>(ofr)))
return dyn_cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();

return std::nullopt;
}
Expand Down Expand Up @@ -185,7 +185,7 @@ OpFoldResult mulOFRValue(const OpFoldResult lhs, const Value rhs,

// 2. if lhs is not constant
assert(!lhsIntAttr);
auto mulOp = b.create<arith::MulIOp>(loc, lhs.get<Value>(), rhs);
auto mulOp = b.create<arith::MulIOp>(loc, cast<Value>(lhs), rhs);
return mulOp.getResult();
}

Expand Down
12 changes: 6 additions & 6 deletions lib/Analysis/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,12 +862,12 @@ void PtrAnalysis::rewriteAdvanceOp(
op.getLoc(), rewriter.getIndexAttr(0));
offsetValue = constOp.getResult();
} else {
offsetValue = offset.get<Value>();
offsetValue = cast<Value>(offset);
}
auto castOp = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), increment);
auto mulOp = rewriter.create<arith::MulIOp>(loc, castOp.getResult(),
stride.get<Value>());
cast<Value>(stride));
auto addOp =
rewriter.create<arith::AddIOp>(loc, mulOp.getResult(), offsetValue);
newOffsets.push_back(addOp.getResult());
Expand Down Expand Up @@ -999,15 +999,15 @@ void PtrAnalysis::rewriteYieldOp(
op.getLoc(), rewriter.getIndexAttr(0));
operands.push_back(constOp.getResult());
} else {
operands.push_back(s.get<Value>());
operands.push_back(cast<Value>(s));
}
}

for (auto s : state.strides) {
assert(!getIntAttr(s) && "PtrState strides for yield within for "
"loop not expected to be "
"attribute.");
operands.push_back(s.get<Value>());
operands.push_back(cast<Value>(s));
}
}

Expand Down Expand Up @@ -1171,7 +1171,7 @@ void PtrAnalysis::rewriteForOp(
newInitArgs.push_back(constOp.getResult());
state.offsets[j] = constOp.getResult();
} else {
newInitArgs.push_back(s.get<Value>());
newInitArgs.push_back(cast<Value>(s));
}
}

Expand All @@ -1183,7 +1183,7 @@ void PtrAnalysis::rewriteForOp(
newInitArgs.push_back(constOp.getResult());
state.strides[j] = constOp.getResult();
} else {
newInitArgs.push_back(s.get<Value>());
newInitArgs.push_back(cast<Value>(s));
}
}

Expand Down
8 changes: 4 additions & 4 deletions lib/AnalysisStructured/PtrAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,12 +793,12 @@ LogicalResult PtrAnalysis::rewriteAdvanceOp(triton::AdvanceOp op) {
loc, builder.getIndexAttr(offsetIntAttr.value()));
offsetValue = constOp.getResult();
} else {
offsetValue = offset.get<Value>();
offsetValue = cast<Value>(offset);
}
auto castOp = builder.create<arith::IndexCastOp>(
loc, builder.getIndexType(), increment);
auto mulOp = builder.create<arith::MulIOp>(loc, castOp.getResult(),
stride.get<Value>());
cast<Value>(stride));
auto addOp =
builder.create<arith::AddIOp>(loc, mulOp.getResult(), offsetValue);
newOffsets.push_back(addOp.getResult());
Expand Down Expand Up @@ -1029,7 +1029,7 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
replacements.push_back(constOp.getResult());
} else {
replacements.push_back(s.get<Value>());
replacements.push_back(cast<Value>(s));
}
}

Expand All @@ -1040,7 +1040,7 @@ PtrAnalysis::rewriteGetStructuredStateOp(tts::GetStructuredStateOp op) {
op.getLoc(), builder.getIndexAttr(sIntAttr.value()));
replacements.push_back(constOp.getResult());
} else {
replacements.push_back(s.get<Value>());
replacements.push_back(cast<Value>(s));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class TritonArithToLinalgPass

tensor::populateDecomposeTensorConcatPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
return failure();
}
return success();
Expand All @@ -103,7 +103,7 @@ class TritonArithToLinalgPass
{
RewritePatternSet patterns(&getContext());
populateTritonArithToLinalgCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonToLinalg/TritonToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class TritonToLinalgPass : public TritonToLinalgBase<TritonToLinalgPass> {
{
RewritePatternSet patterns(&getContext());
populateTritonToLinalgCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
9 changes: 7 additions & 2 deletions python/examples/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,17 @@ def device(request):
# tt.gather not supported yet
"test_gather",
"test_gather_warp_shuffle",
# device 'cpu' does not have 'index
# device 'cpu' does not have 'index'
"test_zero_strided_tensors",
# hard-coded with 'ttg' attributes
"test_convert_mma2mma",
"test_local_load_store",
"test_local_load_store_mma"
"test_local_load_store_mma",
"test_convert_warp_local",
# hard-code to use 'cuda' device
"test_scan_1d",
"test_tma_load_block_shape_err",
"test_tma_store_block_shape_err"
}

# probably different version of MLIR on the nightly build machine is complaining
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ module {
%subview = memref.subview %reinterpret_cast[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
%subview_0 = memref.subview %alloc[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
memref.copy %subview, %subview_0 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
%10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32>
%10 = bufferization.to_tensor %alloc restrict writable : memref<1024xf32> to tensor<1024xf32>
%reinterpret_cast_1 = memref.reinterpret_cast %1 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
%alloc_2 = memref.alloc() : memref<1024xf32>
%subview_3 = memref.subview %reinterpret_cast_1[0] [%9] [1] : memref<1024xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
%subview_4 = memref.subview %alloc_2[0] [%9] [1] : memref<1024xf32> to memref<?xf32, strided<[1]>>
memref.copy %subview_3, %subview_4 : memref<?xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1]>>
%11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32>
%11 = bufferization.to_tensor %alloc_2 restrict writable : memref<1024xf32> to tensor<1024xf32>
%12 = arith.addf %10, %11 : tensor<1024xf32>
%reinterpret_cast_5 = memref.reinterpret_cast %0 to offset: [%5], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
%extracted_slice = tensor.extract_slice %12[0] [%9] [1] : tensor<1024xf32> to tensor<?xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module {
%8 = tt.splat %arg3 : i32 -> tensor<1024xi32>
%9 = arith.cmpi slt, %7, %8 : tensor<1024xi32>
%cast = memref.cast %2 : memref<*xf32> to memref<?xf32>
%10 = bufferization.to_tensor %cast restrict : memref<?xf32>
%10 = bufferization.to_tensor %cast restrict : memref<?xf32> to tensor<?xf32>
%11 = tensor.empty() : tensor<1024xf32>
%12 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%11 : tensor<1024xf32>) {
^bb0(%in: i32, %in_2: i1, %out: f32):
Expand All @@ -30,7 +30,7 @@ module {
linalg.yield %17 : f32
} -> tensor<1024xf32>
%cast_0 = memref.cast %1 : memref<*xf32> to memref<?xf32>
%13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32>
%13 = bufferization.to_tensor %cast_0 restrict : memref<?xf32> to tensor<?xf32>
%14 = tensor.empty() : tensor<1024xf32>
%15 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%7, %9 : tensor<1024xi32>, tensor<1024xi1>) outs(%14 : tensor<1024xf32>) {
^bb0(%in: i32, %in_2: i1, %out: f32):
Expand Down
7 changes: 3 additions & 4 deletions test/Conversion/TritonToLinalg/addptr_for_accumulation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ module {
// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<4x256xbf16>
// CHECK: %[[VAL_17:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
// CHECK: %[[VAL_18:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_17]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
// CHECK: %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]], %[[VAL_24:.*]] = %[[VAL_12]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index, index) {
// CHECK: %[[VAL_19:.*]]:3 = scf.for %[[VAL_20:.*]] = %[[VAL_12]] to %[[VAL_11]] step %[[VAL_10]] iter_args(%[[VAL_21:.*]] = %[[VAL_16]], %[[VAL_22:.*]] = %[[VAL_18]], %[[VAL_23:.*]] = %[[VAL_17]]) -> (tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index) {
// CHECK: %[[VAL_25:.*]] = memref.alloc() : memref<4x256xbf16>
// CHECK: memref.copy %[[VAL_22]], %[[VAL_25]] : memref<4x256xbf16, strided<[?, ?], offset: ?>> to memref<4x256xbf16>
// CHECK: %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_25]] restrict writable : memref<4x256xbf16>
Expand All @@ -81,10 +81,9 @@ module {
// CHECK: %[[VAL_31:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : bf16
// CHECK: linalg.yield %[[VAL_31]] : bf16
// CHECK: } -> tensor<4x256xbf16>
// CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index
// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_32]], %[[VAL_24]] : index
// CHECK: %[[VAL_33:.*]] = arith.addi %[[VAL_23]], %[[VAL_10]] : index
// CHECK: %[[VAL_34:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_33]]], sizes: [4, 256], strides: {{\[}}%[[VAL_9]], %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[?, ?], offset: ?>>
// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]], %[[VAL_12]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index, index
// CHECK: scf.yield %[[VAL_35:.*]], %[[VAL_34]], %[[VAL_33]] : tensor<4x256xbf16>, memref<4x256xbf16, strided<[?, ?], offset: ?>>, index
// CHECK: }
// CHECK: %[[VAL_36:.*]] = arith.index_cast %[[VAL_3]] : i32 to index
// CHECK: %[[VAL_37:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: {{\[}}%[[VAL_36]]], sizes: [4, 256], strides: [1, %[[VAL_8]]] : memref<*xbf16> to memref<4x256xbf16, strided<[1, ?], offset: ?>>
Expand Down
25 changes: 11 additions & 14 deletions test/Conversion/TritonToLinalg/block_ptr_advance.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ module {
// CHECK: module {
// CHECK: func.func @matmul_kernel_with_block_pointers_01234567891011(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: memref<*xbf16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32, %arg18: i32, %arg19: i32) {
// CHECK: %c64 = arith.constant 64 : index
// CHECK: %c0 = arith.constant 0 : index
// CHECK: %c256_i32 = arith.constant 256 : i32
// CHECK: %c0_i32 = arith.constant 0 : i32
// CHECK: %c64_i32 = arith.constant 64 : i32
Expand All @@ -51,7 +50,7 @@ module {
// CHECK: %7 = arith.addi %5, %6 : index
// CHECK: %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%7], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
// CHECK: %reinterpret_cast_0 = memref.reinterpret_cast %arg0 to offset: [%5], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
// CHECK: %8:7 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %c0, %arg26 = %5, %arg27 = %c0) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index) : i32 {
// CHECK: %8:5 = scf.for %arg20 = %c0_i32 to %arg5 step %c64_i32 iter_args(%arg21 = %1, %arg22 = %reinterpret_cast, %arg23 = %reinterpret_cast_0, %arg24 = %7, %arg25 = %5) -> (tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index) : i32 {
// CHECK: %alloc = memref.alloc() : memref<128x64xbf16>
// CHECK: memref.copy %arg22, %alloc : memref<128x64xbf16, strided<[?, ?], offset: ?>> to memref<128x64xbf16>
// CHECK: %17 = bufferization.to_tensor %alloc restrict writable : memref<128x64xbf16>
Expand All @@ -60,23 +59,21 @@ module {
// CHECK: %18 = bufferization.to_tensor %alloc_2 restrict writable : memref<128x64xbf16>
// CHECK: %19 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%17, %18 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%17 : tensor<128x64xbf16>) {
// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16):
// CHECK: %27 = arith.addf %in, %in_5 : bf16
// CHECK: linalg.yield %27 : bf16
// CHECK: %25 = arith.addf %in, %in_5 : bf16
// CHECK: linalg.yield %25 : bf16
// CHECK: } -> tensor<128x64xbf16>
// CHECK: %20 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg21, %19 : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%arg21 : tensor<128x64xbf16>) {
// CHECK: ^bb0(%in: bf16, %in_5: bf16, %out: bf16):
// CHECK: %27 = arith.addf %in, %in_5 : bf16
// CHECK: linalg.yield %27 : bf16
// CHECK: %25 = arith.addf %in, %in_5 : bf16
// CHECK: linalg.yield %25 : bf16
// CHECK: } -> tensor<128x64xbf16>
// CHECK: %21 = arith.muli %4, %c64 : index
// CHECK: %22 = arith.addi %21, %arg25 : index
// CHECK: %23 = arith.addi %arg24, %22 : index
// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%23], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
// CHECK: %24 = arith.muli %3, %c64 : index
// CHECK: %25 = arith.addi %24, %arg26 : index
// CHECK: %26 = arith.addi %25, %arg27 : index
// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%26], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %23, %c0, %26, %c0 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index, index, index
// CHECK: %22 = arith.addi %arg24, %21 : index
// CHECK: %reinterpret_cast_3 = memref.reinterpret_cast %arg0 to offset: [%22], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
// CHECK: %23 = arith.muli %3, %c64 : index
// CHECK: %24 = arith.addi %23, %arg25 : index
// CHECK: %reinterpret_cast_4 = memref.reinterpret_cast %arg0 to offset: [%24], sizes: [128, 64], strides: [%3, %4] : memref<*xbf16> to memref<128x64xbf16, strided<[?, ?], offset: ?>>
// CHECK: scf.yield %20, %reinterpret_cast_3, %reinterpret_cast_4, %22, %24 : tensor<128x64xbf16>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, memref<128x64xbf16, strided<[?, ?], offset: ?>>, index, index
// CHECK: }
// CHECK: %9 = arith.muli %arg13, %c256_i32 : i32
// CHECK: %10 = arith.index_cast %arg12 : i32 to index
Expand Down
1 change: 1 addition & 0 deletions tools/triton-shared-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ target_link_libraries(triton-shared-opt PRIVATE
TritonAnalysis
TritonTransforms
TritonGPUTransforms
TritonTestDialectTritonGPU
TritonSharedAnalysis
${dialect_libs}
${conversion_libs}
Expand Down
2 changes: 1 addition & 1 deletion triton
Submodule triton updated 292 files

0 comments on commit 6f718b7

Please sign in to comment.