@@ -39,23 +39,29 @@ def __init__(
39
39
model_path ,
40
40
unitxt_recipe : str ,
41
41
):
42
- task ,tasks_dir = self .prepare_unitxt_files (unitxt_recipe )
42
+ task = self .assign_task_name ()
43
+ tasks_dir = self .assign_tasks_dir (task )
43
44
super ().__init__ (
44
45
model_path = model_path ,
45
46
tasks_dir = tasks_dir ,
46
47
tasks = [task ],
47
48
few_shots = 0
48
49
)
50
+ self .unitxt_recipe = unitxt_recipe
49
51
50
- def prepare_unitxt_files (self , unitxt_recipe )-> tuple :
51
- temp_task = str (uuid4 ())
52
- temp_tasks_dir = f'{ TEMP_DIR_PREFIX } _{ temp_task } '
53
- yaml_file = os .path .join (temp_tasks_dir ,f"{ temp_task } .yaml" )
54
- create_unitxt_pointer (temp_tasks_dir )
55
- create_unitxt_yaml (yaml_file = yaml_file , unitxt_recipe = unitxt_recipe , task_name = temp_task )
56
- return temp_task ,temp_tasks_dir
52
+ def assign_tasks_dir (self , task ):
53
+ return f'{ TEMP_DIR_PREFIX } _{ task } '
57
54
58
- def remove_temp_files (self ):
55
+ def assign_task_name (self ):
56
+ return str (uuid4 ())
57
+
58
+ def prepare_unitxt_files (self )-> tuple :
59
+ task = self .tasks [0 ]
60
+ yaml_file = os .path .join (self .tasks_dir ,f"{ task } .yaml" )
61
+ create_unitxt_pointer (self .tasks_dir )
62
+ create_unitxt_yaml (yaml_file = yaml_file , unitxt_recipe = self .unitxt_recipe , task_name = task )
63
+
64
+ def remove_unitxt_files (self ):
59
65
if self .tasks_dir .startswith (TEMP_DIR_PREFIX ): #to avoid unintended deletion if this class is inherited
60
66
shutil .rmtree (self .tasks_dir )
61
67
else :
@@ -69,6 +75,7 @@ def run(self,server_url: str | None = None) -> tuple:
69
75
overall_scores Average scores for the task group
70
76
individual_scores Individual scores for each task in the task group
71
77
"""
78
+ self .prepare_unitxt_files ()
72
79
logger .debug (locals ())
73
80
os .environ ["TOKENIZERS_PARALLELISM" ] = "true"
74
81
results = self ._run_mmlu (server_url = server_url , return_all_results = True )
@@ -89,7 +96,7 @@ def run(self,server_url: str | None = None) -> tuple:
89
96
logger .error (e )
90
97
logger .error (e .__traceback__ )
91
98
instance_scores = None
92
- self .remove_temp_files ()
99
+ self .remove_unitxt_files ()
93
100
return global_scores ,instance_scores
94
101
95
102
@@ -101,12 +108,12 @@ def create_unitxt_yaml(yaml_file,unitxt_recipe, task_name):
101
108
}
102
109
with open (yaml_file , 'w' ) as file :
103
110
yaml .dump (data , file , default_flow_style = False )
104
- logger .info (f"task { task } unitxt recipe written to { yaml_file } " )
111
+ logger .debug (f"task { task } unitxt recipe written to { yaml_file } " )
105
112
106
113
def create_unitxt_pointer (tasks_dir ):
107
114
class_line = "class: !function " + task .__file__ .replace ("task.py" , "task.Unitxt" )
108
115
output_file = os .path .join (tasks_dir ,'unitxt' )
109
116
os .makedirs (os .path .dirname (output_file ), exist_ok = True )
110
117
with open (output_file , 'w' ) as f :
111
118
f .write (class_line )
112
- logger .info (f"Unitxt task pointer written to { output_file } " )
119
+ logger .debug (f"Unitxt task pointer written to { output_file } " )
0 commit comments