Skip to content

Commit b8229e2

Browse files
committed
create+delete temp files in run()
1 parent 68da145 commit b8229e2

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

src/instructlab/eval/unitxt.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,29 @@ def __init__(
3939
model_path,
4040
unitxt_recipe: str,
4141
):
42-
task,tasks_dir = self.prepare_unitxt_files(unitxt_recipe)
42+
task = self.assign_task_name()
43+
tasks_dir = self.assign_tasks_dir(task)
4344
super().__init__(
4445
model_path = model_path,
4546
tasks_dir = tasks_dir,
4647
tasks = [task],
4748
few_shots = 0
4849
)
50+
self.unitxt_recipe = unitxt_recipe
4951

50-
def prepare_unitxt_files(self, unitxt_recipe)->tuple:
51-
temp_task = str(uuid4())
52-
temp_tasks_dir = f'{TEMP_DIR_PREFIX}_{temp_task}'
53-
yaml_file = os.path.join(temp_tasks_dir,f"{temp_task}.yaml")
54-
create_unitxt_pointer(temp_tasks_dir)
55-
create_unitxt_yaml(yaml_file=yaml_file, unitxt_recipe=unitxt_recipe, task_name=temp_task)
56-
return temp_task,temp_tasks_dir
52+
def assign_tasks_dir(self, task):
53+
return f'{TEMP_DIR_PREFIX}_{task}'
5754

58-
def remove_temp_files(self):
55+
def assign_task_name(self):
56+
return str(uuid4())
57+
58+
def prepare_unitxt_files(self)->tuple:
59+
task = self.tasks[0]
60+
yaml_file = os.path.join(self.tasks_dir,f"{task}.yaml")
61+
create_unitxt_pointer(self.tasks_dir)
62+
create_unitxt_yaml(yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task)
63+
64+
def remove_unitxt_files(self):
5965
if self.tasks_dir.startswith(TEMP_DIR_PREFIX): #to avoid unintended deletion if this class is inherited
6066
shutil.rmtree(self.tasks_dir)
6167
else:
@@ -69,6 +75,7 @@ def run(self,server_url: str | None = None) -> tuple:
6975
overall_scores Average scores for the task group
7076
individual_scores Individual scores for each task in the task group
7177
"""
78+
self.prepare_unitxt_files()
7279
logger.debug(locals())
7380
os.environ["TOKENIZERS_PARALLELISM"] = "true"
7481
results = self._run_mmlu(server_url=server_url, return_all_results=True)
@@ -89,7 +96,7 @@ def run(self,server_url: str | None = None) -> tuple:
8996
logger.error(e)
9097
logger.error(e.__traceback__)
9198
instance_scores = None
92-
self.remove_temp_files()
99+
self.remove_unitxt_files()
93100
return global_scores,instance_scores
94101

95102

@@ -101,12 +108,12 @@ def create_unitxt_yaml(yaml_file,unitxt_recipe, task_name):
101108
}
102109
with open(yaml_file, 'w') as file:
103110
yaml.dump(data, file, default_flow_style=False)
104-
logger.info(f"task {task} unitxt recipe written to {yaml_file}")
111+
logger.debug(f"task {task} unitxt recipe written to {yaml_file}")
105112

106113
def create_unitxt_pointer(tasks_dir):
107114
class_line = "class: !function " + task.__file__.replace("task.py", "task.Unitxt")
108115
output_file = os.path.join(tasks_dir,'unitxt')
109116
os.makedirs(os.path.dirname(output_file), exist_ok=True)
110117
with open(output_file, 'w') as f:
111118
f.write(class_line)
112-
logger.info(f"Unitxt task pointer written to {output_file}")
119+
logger.debug(f"Unitxt task pointer written to {output_file}")

0 commit comments

Comments
 (0)