From ac0357a35fcd1fcae0eaf5a5da75d48acd1b9e7c Mon Sep 17 00:00:00 2001 From: almaz Date: Tue, 22 Oct 2024 16:05:31 +0200 Subject: [PATCH] handle batch size --- src/ui/evaluation.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/ui/evaluation.py b/src/ui/evaluation.py index 3c0ae82..e29067d 100644 --- a/src/ui/evaluation.py +++ b/src/ui/evaluation.py @@ -90,14 +90,22 @@ def run_evaluation( res_dir = f"/model-benchmark/{project.id}_{project.name}/{task_dir}/" res_dir = g.api.storage.get_free_dir_name(g.team_id, res_dir) - bm.run_evaluation(model_session=g.session_id) + session_info = g.session.get_session_info() + support_batch_inference = session_info.get("batch_inference_support", False) + max_batch_size = session_info.get("max_batch_size") + batch_size = 16 + if not support_batch_inference: + batch_size = 1 + if max_batch_size is not None: + batch_size = min(max_batch_size, 16) + bm.run_evaluation(model_session=g.session_id, batch_size=batch_size) try: batch_sizes = (1, 8, 16) - session_info = g.session.get_session_info() - support_batch_inference = session_info.get("batch_inference_support", False) if not support_batch_inference: batch_sizes = (1,) + elif max_batch_size is not None: + batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size]) bm.run_speedtest(g.session_id, g.project_id, batch_sizes=batch_sizes) sec_pbar.hide() bm.upload_speedtest_results(res_dir + "/speedtest/")