From c3e35fdf1a33c992ce734386f63f8a307e3e042b Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Wed, 27 Mar 2024 12:37:49 -0700 Subject: [PATCH] [E2E full NNs] Add some cnns for verification This commit adds Resnet, Resnext and Vgg e2e tests. Signed-off-by: Dmitrii Makarenko --- .../linalg_on_tensors_backends/refbackend.py | 2 +- .../torch_mlir_e2e_test/test_suite/mlp.py | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index a3d31ca728bc..f33db6a00dce 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -210,7 +210,7 @@ def compile(self, imported_module: Module, ir_file: str = None): run_pipeline_with_repro_report( imported_module, LOWERING_PIPELINE, "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", - enable_ir_printing=False, ir_file=ir_file) + enable_ir_printing=False, ir_dump_file=ir_file) return imported_module def load(self, module) -> RefBackendInvoker: diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py index b568bd4e28e0..8863fe560d73 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/mlp.py @@ -140,3 +140,57 @@ def MLP_basic(module, tu: TestUtils): out = module.forward(test_input) print("[test body] out shape: ", out.size()) + + +from torchvision.models import ( + vgg16, + resnet18, + resnet50, + resnext50_32x4d, + resnext101_32x8d, + densenet121, + efficientnet_v2_m, + mobilenet_v3_large, +) + + +def ResNext(): + torch.manual_seed(0) + model = resnext50_32x4d() + model.eval() + return model + + +def ResNet(): + torch.manual_seed(0) + model = resnet50() + model.eval() + return model + + +def Vgg(): + torch.manual_seed(0) + model = vgg16() + model.eval() + return model + + +@register_test_case(module_factory=lambda: ResNext()) +def ResNext_basic(module, tu: TestUtils): + # out = module.forward(tu.randint(1, 11, high=13000)) + out = module.forward(tu.rand(1, 3, 224, 224)) + # model.forward(input_ids=input_ids.input_ids, attention_mask=input_ids.attention_mask, output_hidden_states=False, use_cache=False) + # print("gen tokens: ", gen_tokens) + return out + + +@register_test_case(module_factory=lambda: ResNet()) +def ResNet_basic(module, tu: TestUtils): + out = module.forward(tu.rand(1, 3, 224, 224)) + return out + + +@register_test_case(module_factory=lambda: Vgg()) +def Vgg_basic(module, tu: TestUtils): + out = module.forward(tu.rand(1, 3, 224, 224)) + return out