Skip to content

Commit

Permalink
handle batch size
Browse files Browse the repository at this point in the history
  • Loading branch information
almazgimaev committed Oct 22, 2024
1 parent 6e1602f commit ac0357a
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/ui/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,22 @@ def run_evaluation(
res_dir = f"/model-benchmark/{project.id}_{project.name}/{task_dir}/"
res_dir = g.api.storage.get_free_dir_name(g.team_id, res_dir)

bm.run_evaluation(model_session=g.session_id)
session_info = g.session.get_session_info()
support_batch_inference = session_info.get("batch_inference_support", False)
max_batch_size = session_info.get("max_batch_size")
batch_size = 16
if not support_batch_inference:
batch_size = 1
if max_batch_size is not None:
batch_size = min(max_batch_size, 16)
bm.run_evaluation(model_session=g.session_id, batch_size=batch_size)

try:
batch_sizes = (1, 8, 16)
session_info = g.session.get_session_info()
support_batch_inference = session_info.get("batch_inference_support", False)
if not support_batch_inference:
batch_sizes = (1,)
elif max_batch_size is not None:
batch_sizes = tuple([bs for bs in batch_sizes if bs <= max_batch_size])
bm.run_speedtest(g.session_id, g.project_id, batch_sizes=batch_sizes)
sec_pbar.hide()
bm.upload_speedtest_results(res_dir + "/speedtest/")
Expand Down

0 comments on commit ac0357a

Please sign in to comment.