Skip to content

Commit

Permalink
Add tests for Torch to Linalg backend pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
cferry-AMD authored Jun 4, 2024
2 parents 72bcc30 + 24c1d2b commit 6771e04
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{verify=0})' -split-input-file %s | FileCheck %s

// CHECK: func.func @tosa
func.func @tosa(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: tosa.abs
%1 = tosa.abs %arg0 : (tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// -----

// CHECK: func.func @torch_gemm
func.func @torch_gemm(%arg0: tensor<?x3xf32>, %arg1: tensor<3x?xf32>, %arg2: tensor<?x?xf32>) -> (tensor<?x?xf32> {onnx.name = "gemm"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<?x3xf32> -> !torch.vtensor<[?,3],f32>
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xf32> -> !torch.vtensor<[3,?],f32>
%2 = torch_c.from_builtin_tensor %arg2 : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%3 = torch.aten.mm %0, %1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32>
%4 = torch.aten.add.Tensor %3, %2, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
%5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
%6 = tosa.abs %5 : (tensor<?x?xf32>) -> tensor<?x?xf32>
return %6 : tensor<?x?xf32>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{use-mlprogram=0})' -split-input-file %s | FileCheck %s
// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{use-mlprogram=1})' -split-input-file %s | FileCheck --check-prefix=YES-CHECK %s

// CHECK-NOT: ml_program.global{{.*}}@global_seed
// YES-CHECK: ml_program.global{{.*}}@global_seed
// CHECK: func.func @torch_gemm
func.func @torch_gemm(%arg0: tensor<?x3xf32>, %arg1: tensor<3x?xf32>, %arg2: tensor<?x?xf32>) -> (tensor<?x?xf32> {onnx.name = "gemm"}) attributes {torch.onnx_meta.opset_version = 19 : si64} {
%0 = torch_c.from_builtin_tensor %arg0 : tensor<?x3xf32> -> !torch.vtensor<[?,3],f32>
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xf32> -> !torch.vtensor<[3,?],f32>
%2 = torch_c.from_builtin_tensor %arg2 : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%3 = torch.aten.mm %0, %1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32>
%4 = torch.aten.add.Tensor %3, %2, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
%5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}

0 comments on commit 6771e04

Please sign in to comment.