Skip to content

Commit

Permalink
Training script: Added support for batch predict for ranking, and saving to file. Parsing properly list CLI args
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielspmoreira committed Apr 10, 2023
1 parent 932f366 commit 44eb8ad
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 66 deletions.
56 changes: 49 additions & 7 deletions examples/quick_start/scripts/ranking/args_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ class MtlArgsPrefix(Enum):


# Names of CLI arguments whose comma-separated string value is converted
# into a list of ints (parse_list_arg(..., vtype=int)) in parse_arguments().
INT_LIST_ARGS = ["mlp_layers", "expert_mlp_layers", "tower_layers"]
# Names of CLI arguments whose comma-separated string value is converted
# into a list of strings (parse_list_arg with its default vtype=str) in
# parse_arguments().
STR_LIST_ARGS = [
"tasks",
"tasks_sample_space",
"predict_keep_cols",
"wnd_ignore_combinations",
]


def str2bool(v):
Expand All @@ -37,12 +43,12 @@ def parse_dynamic_args(dyn_args):
return dyn_args_dict


def parse_int_list_arg(value):
# Used because autobench can't provide empty string ("") as argument
if value == "None":
def parse_list_arg(value, vtype=str):
# Used to allow providing empty string ("") as command line argument
if value is None or value == "None":
value = ""

alist = list([int(v.strip()) for v in value.split(",") if v != ""])
alist = list([vtype(v.strip()) for v in value.split(",") if v != ""])
return alist


Expand All @@ -65,7 +71,10 @@ def parse_arguments():

# Parsing str args that contains lists of ints
for a in INT_LIST_ARGS:
new_args[a] = parse_int_list_arg(new_args[a])
new_args[a] = parse_list_arg(new_args[a], vtype=int)

for a in STR_LIST_ARGS:
new_args[a] = parse_list_arg(new_args[a])

# logging.info(f"ARGUMENTS: {new_args}")

Expand All @@ -82,6 +91,39 @@ def build_arg_parser():
parser.add_argument("--output_path", default="/results/", help="")
parser.add_argument("--save_trained_model_path", default=None, help="")

parser.add_argument(
"--predict",
type=str2bool,
nargs="?",
const=True,
default=False,
help="If enabled, instead the dataset provided "
"will be used for prediction (instead of evaluation)."
"The prediction scores for the dataset in eval_path will "
"be saved to a file defined in --predict_output_path, "
"according to the --predict_output_format choice.",
)
parser.add_argument(
"--predict_keep_cols",
default=None,
help="Comma-separated list of columns to keep in the output "
"prediction file. If no columns is provided, all columns "
"are kept together with the prediction scores.",
)
parser.add_argument(
"--predict_output_path",
default=None,
help="If provided the prediction scores will be saved to this path. "
"Otherwise, files will be saved to output_path/predictions",
)

parser.add_argument(
"--predict_output_format",
default="parquet",
choices=["parquet", "csv", "tsv"],
help="Format of the output prediction file.",
)

# Tasks
parser.add_argument(
"--tasks",
Expand Down Expand Up @@ -189,9 +231,9 @@ def build_arg_parser():
parser.add_argument("--epochs", default=1, type=int, help="")
parser.add_argument("--optimizer", default="adam", choices=["adagrad", "adam"], help="")

parser.add_argument("--train_metrics_steps", default=50, type=int, help="")
parser.add_argument("--train_metrics_steps", default=10, type=int, help="")
parser.add_argument("--metrics_log_frequency", default=50, type=int, help="")
parser.add_argument("--validation_steps", default=0, type=int, help="")
parser.add_argument("--validation_steps", default=10, type=int, help="")

parser.add_argument("--random_seed", default=42, type=int, help="")
parser.add_argument("--train_steps_per_epoch", type=int, help="")
Expand Down
6 changes: 3 additions & 3 deletions examples/quick_start/scripts/ranking/mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def get_mtl_prediction_tasks(targets, args):
tasks_pos_class_weights = get_mtl_positive_class_weights(targets, args)

if args.tasks_sample_space:
if len(args.tasks.split(",")) != len(args.tasks_sample_space.split(",")):
if len(args.tasks) != len(args.tasks_sample_space):
raise ValueError(
"If --tasks_sample_space is provided, the list of tasks sample "
"(separated by ',') need to match the length of the list "
Expand All @@ -91,9 +91,9 @@ def get_mtl_prediction_tasks(targets, args):
"then you can use empty string ('') for that task. "
"For example: --tasks=click,like --tasks_sample_space=,click"
)
tasks_space = dict(zip(args.tasks.split(","), args.tasks_sample_space.split(",")))
tasks_space = dict(zip(args.tasks, args.tasks_sample_space))
else:
tasks_space = {t: None for t in args.tasks.split(",")}
tasks_space = {t: None for t in args.tasks}

prediction_tasks = []
if Task.BINARY_CLASSIFICATION in targets:
Expand Down
2 changes: 1 addition & 1 deletion examples/quick_start/scripts/ranking/ranking_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def get_wide_and_deep_model(schema, args, prediction_tasks):

ignore_combinations = None
if args.wnd_ignore_combinations:
ignore_combinations = [x.split(":") for x in args.wnd_ignore_combinations.split(",")]
ignore_combinations = [x.split(":") for x in args.wnd_ignore_combinations]

wide_preprocess = [
# One-hot features
Expand Down
Loading

0 comments on commit 44eb8ad

Please sign in to comment.