Skip to content

Commit

Permalink
Kaggarwal/fix generation configs (Azure#1422)
Browse files Browse the repository at this point in the history
* Fix generation config and run properties

* revert run properties fix
  • Loading branch information
aggarwal-k authored Oct 9, 2023
1 parent 66c38c9 commit 7cf8ce2
Showing 1 changed file with 4 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -428,27 +428,25 @@ def model_selector(args: Namespace):
for key2 in mlflow_data["flavors"][key]:
if key2 == "generator_config" and args.task_name == "TextGeneration":
generator_config = mlflow_data["flavors"][key]["generator_config"]
mlflow_ftconf_data.update(
{
mlflow_ftconf_data_temp = {
"load_config_kwargs": copy.deepcopy(generator_config),
"mlflow_ft_conf": {
"mlflow_hftransformers_misc_conf": {
"generator_config": copy.deepcopy(generator_config),
},
},
}
)
mlflow_ftconf_data = deep_update(mlflow_ftconf_data_temp, mlflow_ftconf_data)
elif key2 == "model_hf_load_kwargs":
model_hf_load_kwargs = mlflow_data["flavors"][key]["model_hf_load_kwargs"]
mlflow_ftconf_data.update(
{
mlflow_ftconf_data_temp = {
"mlflow_ft_conf": {
"mlflow_hftransformers_misc_conf": {
"model_hf_load_kwargs": copy.deepcopy(model_hf_load_kwargs),
},
},
}
)
mlflow_ftconf_data = deep_update(mlflow_ftconf_data_temp, mlflow_ftconf_data)
ft_config_data = deep_update(mlflow_ftconf_data, ft_config_data)
logger.info(f"Updated FT config data - {ft_config_data}")
except Exception:
Expand Down

0 comments on commit 7cf8ce2

Please sign in to comment.