From 4d9497973a51e24f3ee68e83674c313b07a2ed38 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Tue, 21 Jun 2022 16:39:39 -0400 Subject: [PATCH 01/10] add foundation of hugectr op --- merlin/systems/dag/ops/hugectr.py | 295 +++++++++++++++++ .../unit/common/parsers/benchmark_parsers.py | 178 ++++++++++ tests/unit/common/parsers/criteo_parsers.py | 139 ++++++++ tests/unit/common/parsers/rossmann_parsers.py | 77 +++++ tests/unit/common/utils.py | 150 +++++++++ tests/unit/systems/hugectr/__init__.py | 0 tests/unit/systems/hugectr/test_hugectr.py | 312 ++++++++++++++++++ 7 files changed, 1151 insertions(+) create mode 100644 merlin/systems/dag/ops/hugectr.py create mode 100644 tests/unit/common/parsers/benchmark_parsers.py create mode 100644 tests/unit/common/parsers/criteo_parsers.py create mode 100644 tests/unit/common/parsers/rossmann_parsers.py create mode 100644 tests/unit/common/utils.py create mode 100644 tests/unit/systems/hugectr/__init__.py create mode 100644 tests/unit/systems/hugectr/test_hugectr.py diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py new file mode 100644 index 000000000..e3f8e071f --- /dev/null +++ b/merlin/systems/dag/ops/hugectr.py @@ -0,0 +1,295 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import os +import pathlib + +import numpy as np +import tritonclient.grpc.model_config_pb2 as model_config +from google.protobuf import text_format + +from merlin.dag import ColumnSelector +from merlin.schema import ColumnSchema, Schema +from merlin.systems.dag.ops.operator import InferenceOperator + + +class HugeCTR(InferenceOperator): + """ + Creates an operator meant to house a HugeCTR model. + Allows the model to run as part of a merlin graph operations for inference. + """ + + def __init__( + self, + model, + max_batch_size=1024, + device_list=None, + hit_rate_threshold=None, + gpucache=None, + freeze_sparse=None, + gpucacheper=None, + label_dim=None, + slots=None, + cat_feature_num=None, + des_feature_num=None, + max_nnz=None, + embedding_vector_size=None, + embeddingkey_long_type=None, + ): + self.model = model + self.max_batch_size = max_batch_size or 1024 + self.device_list = device_list + + # if isinstance(model_or_path,(str, os.PathLike)): + # self.path = model_or_path + # elif isinstance(model_or_path, hugectr.Model): + # self.model = model_or_path + # else: + # raise ValueError( + # "Unsupported type for model_or_path. 
" + # "Must be pathlike or hugectr.Model" + # ) + + self.hugectr_params = dict( + hit_rate_threshold=hit_rate_threshold, + gpucache=gpucache, + freeze_sparse=freeze_sparse, + gpucacheper=gpucacheper, + label_dim=label_dim, + slots=slots, + cat_feature_num=cat_feature_num, + des_feature_num=des_feature_num, + max_nnz=max_nnz, + embedding_vector_size=embedding_vector_size, + embeddingkey_long_type=embeddingkey_long_type, + ) + + super().__init__() + + def compute_input_schema( + self, + root_schema: Schema, + parents_schema: Schema, + deps_schema: Schema, + selector: ColumnSelector, + ): + """_summary_ + + Parameters + ---------- + root_schema : Schema + The original schema to the graph. + parents_schema : Schema + A schema comprised of the output schemas of all parent nodes. + deps_schema : Schema + A concatenation of the output schemas of all dependency nodes. + selector : ColumnSelector + Sub selection of columns required to compute the input schema. + + Returns + ------- + Schema + A schema describing the inputs of the model. + """ + return Schema( + [ + ColumnSchema("DES", dtype=np.float32), + ColumnSchema("CATCOLUMN", dtype=np.int64), + ColumnSchema("ROWINDEX", dtype=np.int32), + ] + ) + + def compute_output_schema( + self, + input_schema: Schema, + col_selector: ColumnSelector, + prev_output_schema: Schema = None, + ): + """Return output schema of the model. + + Parameters + ---------- + input_schema : Schema + Schema representing inputs to the model + col_selector : ColumnSelector + list of columns to focus on from input schema + prev_output_schema : Schema, optional + The output schema of the previous node, by default None + + Returns + ------- + Schema + Schema describing the output of the model. + """ + return Schema([ColumnSchema("OUTPUT0", dtype=np.float32)]) + + def export(self, path, input_schema, output_schema, node_id=None, version=1): + """Create and export the required config files for the hugectr model. + + Parameters + ---------- + path : current path of the model + _description_ + input_schema : Schema + Schema describing inputs to model + output_schema : Schema + Schema describing outputs of model + node_id : int, optional + The node's position in execution chain, by default None + version : int, optional + The version of the model, by default 1 + + Returns + ------- + config + Dictionary representation of config file in memory. 
+ """ + node_name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name + node_export_path = pathlib.Path(path) / node_name + node_export_path.mkdir(exist_ok=True) + + hugectr_model_path = pathlib.Path(node_export_path) / str(version) + hugectr_model_path.mkdir(exist_ok=True) + self.model.graph_to_json(graph_config_file=str(hugectr_model_path / "model.json")) + self.model.save_params_to_files(str(hugectr_model_path) + "/") + # generate config + # save artifacts to model repository (path) + # {node_id}_hugectr/config.pbtxt + # {node_id}_hugectr/1/ + model_name = "model" + dense_pattern = "*_dense_*.model" + dense_path = [ + os.path.join(hugectr_model_path, path.name) + for path in hugectr_model_path.glob(dense_pattern) + ][0] + sparse_pattern = "*_sparse_*.model" + sparse_paths = [ + os.path.join(hugectr_model_path, path.name) + for path in hugectr_model_path.glob(sparse_pattern) + ] + network_file = os.path.join(hugectr_model_path, f"{model_name}.json") + + config_dict = dict() + config_dict["supportlonglong"] = True + model = dict() + model["model"] = model_name + model["sparse_files"] = sparse_paths + model["dense_file"] = dense_path + model["network_file"] = network_file + model["num_of_worker_buffer_in_pool"] = 4 + model["num_of_refresher_buffer_in_pool"] = 1 + model["deployed_device_list"] = self.device_list + model["max_batch_size"] = (self.max_batch_size,) + model["default_value_for_each_table"] = [0.0] + model["hit_rate_threshold"] = 0.9 + model["gpucacheper"] = 0.5 + model["gpucache"] = True + model["cache_refresh_percentage_per_iteration"] = 0.2 + config_dict["models"] = [model] + + parameter_server_config_path = str(hugectr_model_path / "ps.json") + with open(parameter_server_config_path, "w") as f: + f.write(json.dumps(config_dict)) + + self.hugectr_params["config"] = parameter_server_config_path + config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size) + + with open(os.path.join(node_export_path, "config.pbtxt"), "w") as o: + text_format.PrintMessage(config, o) + + return config + + +def _hugectr_config(name, hugectr_params, max_batch_size=None): + """Create a config for a HugeCTR model. + + Parameters + ---------- + name : string + The name of the hugectr model. + hugectr_params : dictionary + Dictionary holding parameter values required by hugectr + max_batch_size : int, optional + The maximum batch size to be processed per batch, by an inference request, by default None + + Returns + ------- + config + Dictionary representation of hugectr config. 
+ """ + config = model_config.ModelConfig(name=name, backend="hugectr", max_batch_size=max_batch_size) + + config.input.append( + model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1]) + ) + + config.input.append( + model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1]) + ) + + config.input.append( + model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1]) + ) + + config.output.append( + model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1]) + ) + + config.instance_group.append(model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1)) + + config_hugectr = model_config.ModelParameter(string_value=hugectr_params["config"]) + config.parameters["config"].CopyFrom(config_hugectr) + + gpucache_val = hugectr_params.get("gpucache", "true") + + gpucache = model_config.ModelParameter(string_value=gpucache_val) + config.parameters["gpucache"].CopyFrom(gpucache) + + gpucacheper_val = str(hugectr_params.get("gpucacheper_val", "0.5")) + + gpucacheper = model_config.ModelParameter(string_value=gpucacheper_val) + config.parameters["gpucacheper"].CopyFrom(gpucacheper) + + label_dim = model_config.ModelParameter(string_value=str(hugectr_params["label_dim"])) + config.parameters["label_dim"].CopyFrom(label_dim) + + slots = model_config.ModelParameter(string_value=str(hugectr_params["slots"])) + config.parameters["slots"].CopyFrom(slots) + + des_feature_num = model_config.ModelParameter( + string_value=str(hugectr_params["des_feature_num"]) + ) + config.parameters["des_feature_num"].CopyFrom(des_feature_num) + + cat_feature_num = model_config.ModelParameter( + string_value=str(hugectr_params["cat_feature_num"]) + ) + config.parameters["cat_feature_num"].CopyFrom(cat_feature_num) + + max_nnz = model_config.ModelParameter(string_value=str(hugectr_params["max_nnz"])) + config.parameters["max_nnz"].CopyFrom(max_nnz) + + embedding_vector_size = model_config.ModelParameter( + string_value=str(hugectr_params["embedding_vector_size"]) + ) + config.parameters["embedding_vector_size"].CopyFrom(embedding_vector_size) + + embeddingkey_long_type_val = hugectr_params.get("embeddingkey_long_type", "true") + + embeddingkey_long_type = model_config.ModelParameter(string_value=embeddingkey_long_type_val) + config.parameters["embeddingkey_long_type"].CopyFrom(embeddingkey_long_type) + + return config diff --git a/tests/unit/common/parsers/benchmark_parsers.py b/tests/unit/common/parsers/benchmark_parsers.py new file mode 100644 index 000000000..85ffbc62d --- /dev/null +++ b/tests/unit/common/parsers/benchmark_parsers.py @@ -0,0 +1,178 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import datetime +import time + +from asvdb import BenchmarkResult + + +class Benchmark: + """ + Main general benchmark parsing class + """ + + def __init__(self, target_id, val=1, split=None): + self.name = f"{target_id}" + self.val = val + self.split = split + + def get_epoch(self, line): + raise NotImplementedError("Must Define logic for parsing metrics per epoch") + + def get_epochs(self, output): + raise NotImplementedError("Must Define logic for parsing output") + + +# Sub classes + + +class StandardBenchmark(Benchmark): + def get_info(self, output): + bench_infos = [] + losses = [] + for line in output: + if "run_time" in line: + bench_infos.append(line) + if "loss" in line: + losses.append(line) + loss_dict = {} + if losses: + loss_dict = {"loss": self.get_loss(losses[-1])} + if bench_infos: + bench_infos = self.get_dl_timing(bench_infos[-1:], optionals=loss_dict) + return bench_infos + + def get_dl_thru( + self, full_time, num_rows, epochs, throughput, optionals=None + ) -> BenchmarkResult: + metrics = [("thru", throughput), ("rows", num_rows), ("epochs", epochs)] + optionals = optionals or {} + for metric_name, metric_value in optionals.items(): + metrics.append((metric_name, metric_value)) + return create_bench_result( + f"{self.name}_dataloader", + metrics, + full_time, + "seconds", + ) + + def get_loss(self, line): + return float(line) + + def loss(self, epoch, loss, l_type="train") -> BenchmarkResult: + return create_bench_result( + f"{self.name}_{l_type}_loss", [("epoch", epoch)], loss, "percent" + ) + + def rmspe(self, epoch, rmspe) -> BenchmarkResult: + return create_bench_result(f"{self.name}_exp_rmspe", [("epoch", epoch)], rmspe, "percent") + + def acc(self, epoch, acc) -> BenchmarkResult: + return create_bench_result(f"{self.name}_exp_rmspe", [("epoch", epoch)], acc, "percent") + + def roc_auc(self, epoch, acc) -> BenchmarkResult: + return create_bench_result(f"{self.name}_exp_rmspe", [("epoch", epoch)], acc, "percent") + + def time(self, epoch, r_time, time_format="%M:%S") -> BenchmarkResult: + if time_format: + x = time.strptime(r_time.split(",")[0], time_format) + r_time = datetime.timedelta( + hours=x.tm_hour, minutes=x.tm_min, seconds=x.tm_sec + ).total_seconds() + return create_bench_result(f"{self.name}_time", [("epoch", epoch)], r_time, "seconds") + + def aps(self, epoch, aps) -> BenchmarkResult: + return create_bench_result(f"{self.name}_Avg_Prec", [("epoch", epoch)], aps, "percent") + + def get_dl_timing(self, output, optionals=None): + timing_res = [] + for line in output: + if line.startswith("run_time"): + run_time, num_rows, epochs, dl_thru = line.split(" - ") + run_time = float(run_time.split(": ")[1]) + num_rows = int(num_rows.split(": ")[1]) + epochs = int(epochs.split(": ")[1]) + dl_thru = float(dl_thru.split(": ")[1]) + bres = self.get_dl_thru( + run_time, num_rows * epochs, epochs, dl_thru, optionals=optionals + ) + timing_res.append(bres) + return timing_res[-1:] + + +class BenchFastAI(StandardBenchmark): + def __init__(self, target_id, val=6, split=None): + super().__init__(f"{target_id}_fastai", val=val, split=split) + + def get_epochs(self, output): + epochs = [] + for line in output: + split_line = line.split(self.split) if self.split else line.split() + if len(split_line) == self.val and is_whole_number(split_line[0]): + # epoch line, detected based on if 1st character is a number + post_evts = self.get_epoch(line) + epochs.append(post_evts) + if "run_time" in line: + epochs.append(self.get_dl_timing(line)) + return epochs[-1:] + + +# 
Utils + + +def is_whole_number(str_to_num): + try: + int(str_to_num) + return True + except ValueError: + return False + + +def is_float(str_to_flt): + try: + float(str_to_flt) + return True + except ValueError: + return False + + +def send_results(db, bench_info, results_list): + # only one entry because entries are split by Bench info + new_results_list = results_list + info_list = list(db.getInfo()) + if len(info_list) > 0: + br_list = db.getResults(filterInfoObjList=[bench_info]) + if br_list: + br_list = br_list[0][1] + results_to_remove = [] + for result in results_list: + if any(br.funcName == result.funcName for br in br_list): + results_to_remove.append(result) + new_results_list = [result for result in results_list if result not in results_to_remove] + # breakpoint() + for results in new_results_list: + if isinstance(results, list): + for result in results: + db.addResult(bench_info, result) + else: + db.addResult(bench_info, results) + + +def create_bench_result(name, arg_tuple_list, result, unit): + return BenchmarkResult( + funcName=name, argNameValuePairs=arg_tuple_list, unit=unit, result=result + ) diff --git a/tests/unit/common/parsers/criteo_parsers.py b/tests/unit/common/parsers/criteo_parsers.py new file mode 100644 index 000000000..a5b23ceb8 --- /dev/null +++ b/tests/unit/common/parsers/criteo_parsers.py @@ -0,0 +1,139 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import re + +from tests.integration.common.parsers.benchmark_parsers import ( + BenchFastAI, + StandardBenchmark, + create_bench_result, +) + +decimal_regex = "[0-9]+\.?[0-9]*|\.[0-9]+" # noqa pylint: disable=W1401 + + +class CriteoBenchFastAI(BenchFastAI): + def __init__(self, name="CriteoFastAI", val=6, split=None): + self.name = name + self.val = val + self.split = split + + def get_info(self, output): + bench_infos = [] + losses = [] + for line in output: + if "run_time" in line: + bench_infos.append(line) + if "loss" in line and "Train" in line and "Valid" in line: + losses.append(line) + loss_dict = {} + if losses: + for loss in losses: + t_loss, v_loss = self.get_loss(loss) + loss_dict["loss_train"] = t_loss + loss_dict["loss_valid"] = v_loss + if bench_infos: + bench_infos = self.get_dl_timing(bench_infos[-1:], optionals=loss_dict) + return bench_infos + + def get_epoch(self, line): + epoch, t_loss, v_loss, roc, aps, o_time = line.split() + t_loss = self.loss(epoch, float(t_loss)) + v_loss = self.loss(epoch, float(v_loss), l_type="valid") + roc = self.roc_auc(epoch, float(roc)) + aps = self.aps(epoch, float(aps)) + return [t_loss, v_loss, roc, aps, o_time] + + def get_loss(self, line): + epoch, t_loss, v_loss, roc, aps, o_time = line.split() + t_loss = float(t_loss) + v_loss = float(v_loss) + return [t_loss, v_loss] + + +class CriteoBenchHugeCTR(StandardBenchmark): + def __init__(self, name="CriteoHugeCTR"): + self.name = name + + def get_epochs(self, output): + aucs = [] + for line in output: + if "AUC:" in line: + auc_num = float(line.split("AUC:")[-1]) + aucs.append(auc_num) + if "run_time:" in line: + run_time = self.get_runtime(line) + if run_time and aucs: + return self.get_epoch(max(aucs), run_time) + return [] + + def get_runtime(self, line): + split_line = line.split(":") + return float(split_line[1]) + + def get_epoch(self, auc, runtime): + bres_auc = create_bench_result(f"{self.name}_auc", [("time", runtime)], auc, "percent") + return [bres_auc] + + +class CriteoTensorflow(StandardBenchmark): + def __init__(self, name="CriteoTensorFlow"): + self.name = name + + def get_loss(self, line): + loss = line.split("-")[-1] + loss = loss.split(":")[-1] + losses = re.findall(decimal_regex, loss) + losses = losses or [] + return float(losses[-1]) + + +class CriteoTorch(StandardBenchmark): + def __init__(self, name="CriteoTorch"): + self.name = name + + def get_info(self, output): + bench_infos = [] + losses = [] + for line in output: + if "run_time" in line: + bench_infos.append(line) + if "loss" in line and "Train" in line and "Valid" in line: + losses.append(line) + loss_dict = {} + if losses: + for idx, loss in enumerate(losses): + t_loss, v_loss = self.get_loss(loss) + loss_dict["loss_train"] = t_loss + loss_dict["loss_valid"] = v_loss + if bench_infos: + bench_infos = self.get_dl_timing(bench_infos[-1:], optionals=loss_dict) + return bench_infos + + def get_loss(self, line): + # Epoch 00. Train loss: 0.1944. Valid loss: 0.1696. + loss_parse = line.split(". 
") + epoch = loss_parse[0].split(" ")[-1] + train_loss = loss_parse[1].split(":")[-1] + valid_loss = loss_parse[2].split(":")[-1] + + epoch = re.findall(decimal_regex, epoch)[-1] + train_loss = re.findall(decimal_regex, train_loss)[-1] + valid_loss = re.findall(decimal_regex, valid_loss)[-1] + + epoch = int(epoch) + train_loss = float(train_loss) + valid_loss = float(valid_loss) + return [train_loss, valid_loss] diff --git a/tests/unit/common/parsers/rossmann_parsers.py b/tests/unit/common/parsers/rossmann_parsers.py new file mode 100644 index 000000000..7538065f8 --- /dev/null +++ b/tests/unit/common/parsers/rossmann_parsers.py @@ -0,0 +1,77 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from tests.integration.common.parsers.benchmark_parsers import BenchFastAI, StandardBenchmark + + +class RossBenchTensorFlow(StandardBenchmark): + def __init__(self, split=" - "): + super().__init__("Rossmann_tf", split=split) + + def get_epoch(self, line, epoch=0): + _, _, t_loss, t_rmspe = line.split(self.split) + t_loss = self.loss(epoch, float(t_loss.split(": ")[1])) + # t_rmspe = self.rmspe(epoch, float(t_rmspe.split(": ")[1])) + return [t_loss, t_rmspe] + + def get_epochs(self, output): + epochs = [] + for idx, line in enumerate(output): + if "Epoch" in line: + epoch = int(line.split()[-1].split("/")[0]) + # output skips line for formatting and remove returns (\x08) + content_line = output[idx + 2].rstrip("\x08") + # epoch line, detected based on if 1st character is a number + post_evts = self.get_epoch(content_line, epoch=epoch) + epochs.append(post_evts) + if "run_time" in line: + epochs.append(self.get_dl_timing(line)) + return epochs[-1:] + + +class RossBenchPytorch(StandardBenchmark): + def __init__(self, split=". 
"): + super().__init__("Rossmann_torch", split=split) + + def get_epoch(self, line): + epoch, t_loss, t_rmspe, v_loss, v_rmspe = line.split(self.split) + epoch = epoch.split()[1] + t_loss = self.loss(epoch, float(t_loss.split(": ")[1])) + v_loss = self.loss(epoch, float(v_loss.split(": ")[1]), l_type="valid") + return [t_loss, v_loss, t_rmspe, v_rmspe] + + def get_epochs(self, output): + epochs = [] + for line in output: + if "Epoch" in line: + # epoch line, detected based on if 1st character is a number + post_evts = self.get_epoch(line) + epochs.append(post_evts) + if "run_time" in line: + epochs.append(self.get_dl_timing(line)) + return epochs[-1:] + + +class RossBenchFastAI(BenchFastAI): + def __init__(self, val=5, split=None): + super().__init__("Rossmann", val=val, split=split) + + def get_epoch(self, line): + epoch, t_loss, v_loss, exp_rmspe, o_time = line.split() + t_loss = self.loss(epoch, float(t_loss)) + v_loss = self.loss(epoch, float(v_loss), l_type="valid") + # exp_rmspe = self.rmspe(epoch, float(exp_rmspe)) + return [t_loss, v_loss, exp_rmspe, o_time] diff --git a/tests/unit/common/utils.py b/tests/unit/common/utils.py new file mode 100644 index 000000000..9f2271457 --- /dev/null +++ b/tests/unit/common/utils.py @@ -0,0 +1,150 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import datetime as dt +import itertools +import json +import os +import shutil +import subprocess +import sys + +import cudf +import cupy as cp + +import nvtabular as nvt + + +def _run_notebook( + tmpdir, + notebook_path, + input_path, + output_path, + batch_size=None, + gpu_id=0, + clean_up=True, + transform=None, + params=None, + main_block=-1, +): + params = params or [] + + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("GPU_TARGET_ID", gpu_id) + + if not os.path.exists(input_path): + os.makedirs(input_path) + if not os.path.exists(output_path): + os.makedirs(output_path) + if batch_size: + os.environ["BATCH_SIZE"] = os.environ.get("BATCH_SIZE", batch_size) + + os.environ["INPUT_DATA_DIR"] = input_path + os.environ["OUTPUT_DATA_DIR"] = output_path + # read in the notebook as JSON, and extract a python script from it + notebook = json.load(open(notebook_path, encoding="utf-8")) + source_cells = [cell["source"] for cell in notebook["cells"] if cell["cell_type"] == "code"] + + lines = [ + transform(line.rstrip()) if transform else line + for line in itertools.chain(*source_cells) + if not (line.startswith("%") or line.startswith("!")) + ] + + # Replace config params + if params: + + def transform_fracs(line): + line = line.replace("device_limit_frac = 0.7", "device_limit_frac = " + str(params[0])) + line = line.replace("device_pool_frac = 0.8", "device_pool_frac = " + str(params[1])) + return line.replace("part_mem_frac = 0.15", "part_mem_frac = " + str(params[2])) + + lines = [transform_fracs(line) for line in lines] + + # Add guarding block and indentation + if main_block >= 0: + lines.insert(main_block, 'if __name__ == "__main__":') + for i in range(main_block + 1, len(lines)): + lines[i] = " " + lines[i] + + # save the script to a file, and run with the current python executable + # we're doing this in a subprocess to avoid some issues using 'exec' + # that were causing a segfault with globals of the exec'ed function going + # out of scope + script_path = os.path.join(tmpdir, "notebook.py") + with open(script_path, "w") as script: + script.write("\n".join(lines)) + output = subprocess.check_output([sys.executable, script_path]) + # save location will default to run location + output = output.decode("utf-8") + _, note_name = os.path.split(notebook_path) + note_name = note_name.split(".")[0] + if output: + with open(f"test_res_{note_name}", "w+") as w_file: + w_file.write(output) + # clear out products + if clean_up: + shutil.rmtree(output_path) + return output + + +def _run_query( + client, + n_rows, + model_name, + workflow_path, + data_path, + actual_output_filename, + output_name, + input_cols_name=None, + backend="tensorflow", +): + + import tritonclient.grpc as grpcclient + from tritonclient.utils import np_to_triton_dtype + + workflow = nvt.Workflow.load(workflow_path) + + if input_cols_name is None: + batch = cudf.read_csv(data_path, nrows=n_rows)[workflow.output_node.input_columns.names] + else: + batch = cudf.read_csv(data_path, nrows=n_rows)[input_cols_name] + + input_dtypes = workflow.input_dtypes + columns = [(col, batch[col]) for col in batch.columns] + + inputs = [] + for i, (name, col) in enumerate(columns): + d = col.values_host.astype(input_dtypes[name]) + d = d.reshape(len(d), 1) + inputs.append(grpcclient.InferInput(name, d.shape, np_to_triton_dtype(input_dtypes[name]))) + inputs[i].set_data_from_numpy(d) + + outputs = [grpcclient.InferRequestedOutput(output_name)] + time_start = dt.datetime.now() + response = client.infer(model_name, inputs, 
request_id="1", outputs=outputs) + run_time = dt.datetime.now() - time_start + + output_key = "output" if backend == "hugectr" else "0" + + output_actual = cudf.read_csv(os.path.expanduser(actual_output_filename), nrows=n_rows) + output_actual = cp.asnumpy(output_actual[output_key].values) + output_predict = response.as_numpy(output_name) + + if backend == "tensorflow": + output_predict = output_predict[:, 0] + + diff = abs(output_actual - output_predict) + return diff, run_time diff --git a/tests/unit/systems/hugectr/__init__.py b/tests/unit/systems/hugectr/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/systems/hugectr/test_hugectr.py b/tests/unit/systems/hugectr/test_hugectr.py new file mode 100644 index 000000000..fa4d34480 --- /dev/null +++ b/tests/unit/systems/hugectr/test_hugectr.py @@ -0,0 +1,312 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import gc +import os +import shutil +from os import path + +import cudf + +from merlin.systems.dag.ops.hugectr import HugeCTR + +try: + import hugectr + from hugectr.inference import CreateInferenceSession, InferenceParams + from mpi4py import MPI # noqa pylint: disable=unused-import +except ImportError: + hugectr = None + +from distutils.spawn import find_executable + +import numpy as np +import pandas as pd + +# from common.parsers.benchmark_parsers import create_bench_result +# from common.utils import _run_query +from sklearn.model_selection import train_test_split + +import nvtabular as nvt +from merlin.core.utils import download_file +from nvtabular.ops import get_embedding_sizes + +DIR = "/raid/data/movielens/data/" +DATA_DIR = DIR + "data/" +TEMP_DIR = DIR + "temp_hugectr/" +MODEL_DIR = DIR + "models/" +TRAIN_DIR = MODEL_DIR + "test_model/1/" +NETWORK_FILE = TRAIN_DIR + "model.json" +DENSE_FILE = TRAIN_DIR + "_dense_1900.model" +SPARSE_FILES = TRAIN_DIR + "0_sparse_1900.model" +MODEL_NAME = "test_model" + +CATEGORICAL_COLUMNS = ["userId", "movieId", "new_cat1"] +LABEL_COLUMNS = ["rating"] +TEST_N_ROWS = 64 + +TRITON_SERVER_PATH = find_executable("tritonserver") +TRITON_DEVICE_ID = "1" + + +def _run_model(slot_sizes, total_cardinality): + solver = hugectr.CreateSolver( + vvgpu=[[0]], + batchsize=2048, + batchsize_eval=2048, + max_eval_batches=160, + i64_input_key=True, + use_mixed_precision=False, + repeat_dataset=True, + ) + + reader = hugectr.DataReaderParams( + data_reader_type=hugectr.DataReaderType_t.Parquet, + source=[DATA_DIR + "train/_file_list.txt"], + eval_source=DATA_DIR + "valid/_file_list.txt", + check_type=hugectr.Check_t.Non, + ) + + optimizer = hugectr.CreateOptimizer(optimizer_type=hugectr.Optimizer_t.Adam) + model = hugectr.Model(solver, reader, optimizer) + + model.add( + hugectr.Input( + label_dim=1, + label_name="label", + dense_dim=0, + dense_name="dense", + data_reader_sparse_param_array=[ + hugectr.DataReaderSparseParam("data1", len(slot_sizes) + 1, True, len(slot_sizes)) + ], + ) + ) + + model.add( + hugectr.SparseEmbedding( 
+ embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, + workspace_size_per_gpu_in_mb=107, + embedding_vec_size=16, + combiner="sum", + sparse_embedding_name="sparse_embedding1", + bottom_name="data1", + slot_size_array=slot_sizes, + optimizer=optimizer, + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.Reshape, + bottom_names=["sparse_embedding1"], + top_names=["reshape1"], + leading_dim=48, + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.InnerProduct, + bottom_names=["reshape1"], + top_names=["fc1"], + num_output=128, + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.ReLU, + bottom_names=["fc1"], + top_names=["relu1"], + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.InnerProduct, + bottom_names=["relu1"], + top_names=["fc2"], + num_output=128, + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.ReLU, + bottom_names=["fc2"], + top_names=["relu2"], + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.InnerProduct, + bottom_names=["relu2"], + top_names=["fc3"], + num_output=1, + ) + ) + model.add( + hugectr.DenseLayer( + layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss, + bottom_names=["fc3", "label"], + top_names=["loss"], + ) + ) + model.compile() + model.summary() + model.fit(max_iter=20, display=100, eval_interval=200, snapshot=10) + model.graph_to_json(graph_config_file=NETWORK_FILE) + + return model + + +def _predict(dense_features, embedding_columns, row_ptrs, config_file, model_name): + inference_params = InferenceParams( + model_name=model_name, + max_batchsize=64, + hit_rate_threshold=0.5, + dense_model_file=DENSE_FILE, + sparse_model_files=[SPARSE_FILES], + device_id=0, + use_gpu_embedding_cache=True, + cache_size_percentage=0.1, + i64_input_key=True, + use_mixed_precision=False, + ) + inference_session = CreateInferenceSession(config_file, inference_params) + output = inference_session.predict(dense_features, embedding_columns, row_ptrs) # , True) + + test_data_path = DATA_DIR + "test/" + embedding_columns_df = pd.DataFrame() + embedding_columns_df["embedding_columns"] = embedding_columns + embedding_columns_df.to_csv(test_data_path + "embedding_columns.csv") + + row_ptrs_df = pd.DataFrame() + row_ptrs_df["row_ptrs"] = row_ptrs + row_ptrs_df.to_csv(test_data_path + "row_ptrs.csv") + + output_df = pd.DataFrame() + output_df["output"] = output + output_df.to_csv(test_data_path + "output.csv") + + +def _convert(data, slot_size_array): + categorical_dim = len(CATEGORICAL_COLUMNS) + batch_size = data.shape[0] + + offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1].tolist() + data[CATEGORICAL_COLUMNS] += offset + cat = data[CATEGORICAL_COLUMNS].values.reshape(1, batch_size * categorical_dim).tolist()[0] + + row_ptrs = list(range(batch_size * categorical_dim + 1)) + dense = [] + + return dense, cat, row_ptrs + + +def test_training(tmpdir): + # Download & Convert data + download_file( + "http://files.grouplens.org/datasets/movielens/ml-25m.zip", + os.path.join(DATA_DIR, "ml-25m.zip"), + ) + + ratings = cudf.read_csv(os.path.join(DATA_DIR, "ml-25m", "ratings.csv")) + ratings["new_cat1"] = ratings["userId"] / ratings["movieId"] + ratings["new_cat1"] = ratings["new_cat1"].astype("int64") + ratings.head() + + ratings = ratings.drop("timestamp", axis=1) + train, valid = train_test_split(ratings, test_size=0.2, random_state=42) + + train.to_parquet(DATA_DIR + "train.parquet") + valid.to_parquet(DATA_DIR + "valid.parquet") + + del train + 
del valid + gc.collect() + + # Perform ETL with NVTabular + cat_features = CATEGORICAL_COLUMNS >> nvt.ops.Categorify(cat_cache="device") + ratings = nvt.ColumnSelector(["rating"]) >> nvt.ops.LambdaOp( + lambda col: (col > 3).astype("int8") + ) + output = cat_features + ratings + + workflow = nvt.Workflow(output) + + train_dataset = nvt.Dataset(DATA_DIR + "train.parquet", part_size="100MB") + valid_dataset = nvt.Dataset(DATA_DIR + "valid.parquet", part_size="100MB") + + workflow.fit(train_dataset) + + dict_dtypes = {} + + for col in CATEGORICAL_COLUMNS: + dict_dtypes[col] = np.int64 + + for col in LABEL_COLUMNS: + dict_dtypes[col] = np.float32 + + if path.exists(DATA_DIR + "train"): + shutil.rmtree(os.path.join(DATA_DIR, "train")) + if path.exists(DATA_DIR + "valid"): + shutil.rmtree(os.path.join(DATA_DIR, "valid")) + + workflow.transform(train_dataset).to_parquet( + output_path=DATA_DIR + "train/", + shuffle=nvt.io.Shuffle.PER_PARTITION, + cats=CATEGORICAL_COLUMNS, + labels=LABEL_COLUMNS, + dtypes=dict_dtypes, + ) + workflow.transform(valid_dataset).to_parquet( + output_path=DATA_DIR + "valid/", + shuffle=False, + cats=CATEGORICAL_COLUMNS, + labels=LABEL_COLUMNS, + dtypes=dict_dtypes, + ) + + # Train with HugeCTR + embeddings = get_embedding_sizes(workflow) + total_cardinality = 0 + slot_sizes = [] + for column in CATEGORICAL_COLUMNS: + slot_sizes.append(embeddings[column][0]) + total_cardinality += embeddings[column][0] + + test_data_path = DATA_DIR + "test/" + if path.exists(test_data_path): + shutil.rmtree(test_data_path) + + os.mkdir(test_data_path) + + if path.exists(MODEL_DIR): + shutil.rmtree(MODEL_DIR) + + os.makedirs(TRAIN_DIR) + + sample_data = cudf.read_parquet(DATA_DIR + "valid.parquet", num_rows=TEST_N_ROWS) + sample_data.to_csv(test_data_path + "data.csv") + + sample_data_trans = nvt.workflow.workflow._transform_partition( + sample_data, [workflow.output_node] + ) + + dense_features, embedding_columns, row_ptrs = _convert(sample_data_trans, slot_sizes) + + model = _run_model(slot_sizes, total_cardinality) + + model_op = HugeCTR(model) + config = model_op.export(tmpdir, None, None) + assert config is not None + + # _predict(dense_features, embedding_columns, row_ptrs, hugectr_params["config"], MODEL_NAME) From 4b94dfa0748db946f2f25af895a0600aa37fd49c Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Mon, 27 Jun 2022 10:13:41 -0400 Subject: [PATCH 02/10] hugectr snapshot --- tests/unit/systems/hugectr/test_hugectr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/systems/hugectr/test_hugectr.py b/tests/unit/systems/hugectr/test_hugectr.py index fa4d34480..6d16b3e59 100644 --- a/tests/unit/systems/hugectr/test_hugectr.py +++ b/tests/unit/systems/hugectr/test_hugectr.py @@ -309,4 +309,4 @@ def test_training(tmpdir): config = model_op.export(tmpdir, None, None) assert config is not None - # _predict(dense_features, embedding_columns, row_ptrs, hugectr_params["config"], MODEL_NAME) + # _predict(dense_features, embedding_columns, row_ptrs, tmpdir, MODEL_NAME) From d7e1382e69cfc1056cb484ce233eace3c82498e1 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Thu, 30 Jun 2022 12:04:31 -0400 Subject: [PATCH 03/10] hugectr op is green for single hot columns --- merlin/systems/dag/ops/hugectr.py | 91 +++--- tests/unit/systems/hugectr/test_hugectr.py | 314 +++++++++------------ tests/unit/systems/utils/triton.py | 12 +- 3 files changed, 191 insertions(+), 226 deletions(-) diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py index 
e3f8e071f..2d5172613 100644 --- a/merlin/systems/dag/ops/hugectr.py +++ b/merlin/systems/dag/ops/hugectr.py @@ -35,45 +35,28 @@ class HugeCTR(InferenceOperator): def __init__( self, model, - max_batch_size=1024, + max_batch_size=64, device_list=None, hit_rate_threshold=None, gpucache=None, freeze_sparse=None, gpucacheper=None, - label_dim=None, - slots=None, - cat_feature_num=None, - des_feature_num=None, - max_nnz=None, - embedding_vector_size=None, + max_nnz=2, embeddingkey_long_type=None, ): self.model = model - self.max_batch_size = max_batch_size or 1024 - self.device_list = device_list - - # if isinstance(model_or_path,(str, os.PathLike)): - # self.path = model_or_path - # elif isinstance(model_or_path, hugectr.Model): - # self.model = model_or_path - # else: - # raise ValueError( - # "Unsupported type for model_or_path. " - # "Must be pathlike or hugectr.Model" - # ) + self.max_batch_size = max_batch_size + self.device_list = device_list or [] + embeddingkey_long_type = embeddingkey_long_type or "true" + gpucache = gpucache or "true" + gpucacheper = gpucacheper or 0.5 self.hugectr_params = dict( hit_rate_threshold=hit_rate_threshold, gpucache=gpucache, freeze_sparse=freeze_sparse, gpucacheper=gpucacheper, - label_dim=label_dim, - slots=slots, - cat_feature_num=cat_feature_num, - des_feature_num=des_feature_num, max_nnz=max_nnz, - embedding_vector_size=embedding_vector_size, embeddingkey_long_type=embeddingkey_long_type, ) @@ -160,51 +143,76 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): node_name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name node_export_path = pathlib.Path(path) / node_name node_export_path.mkdir(exist_ok=True) - + model_name = node_name hugectr_model_path = pathlib.Path(node_export_path) / str(version) hugectr_model_path.mkdir(exist_ok=True) - self.model.graph_to_json(graph_config_file=str(hugectr_model_path / "model.json")) + + network_file = os.path.join(hugectr_model_path, f"{model_name}.json") + + self.model.graph_to_json(graph_config_file=network_file) self.model.save_params_to_files(str(hugectr_model_path) + "/") - # generate config - # save artifacts to model repository (path) - # {node_id}_hugectr/config.pbtxt - # {node_id}_hugectr/1/ - model_name = "model" + model_json = json.loads(open(network_file, "r").read()) dense_pattern = "*_dense_*.model" dense_path = [ os.path.join(hugectr_model_path, path.name) for path in hugectr_model_path.glob(dense_pattern) + if "opt" not in path.name ][0] sparse_pattern = "*_sparse_*.model" sparse_paths = [ os.path.join(hugectr_model_path, path.name) for path in hugectr_model_path.glob(sparse_pattern) + if "opt" not in path.name ] - network_file = os.path.join(hugectr_model_path, f"{model_name}.json") config_dict = dict() config_dict["supportlonglong"] = True + + data_layer = model_json["layers"][0] + sparse_layers = [ + layer + for layer in model_json["layers"] + if layer["type"] == "DistributedSlotSparseEmbeddingHash" + ] + num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"]) + vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers] + model = dict() model["model"] = model_name + model["slot_num"] = num_cat_columns model["sparse_files"] = sparse_paths model["dense_file"] = dense_path + model["maxnum_des_feature_per_sample"] = data_layer["dense"]["dense_dim"] model["network_file"] = network_file model["num_of_worker_buffer_in_pool"] = 4 model["num_of_refresher_buffer_in_pool"] = 1 model["deployed_device_list"] = 
self.device_list - model["max_batch_size"] = (self.max_batch_size,) - model["default_value_for_each_table"] = [0.0] + model["max_batch_size"] = self.max_batch_size + model["default_value_for_each_table"] = [0.0] * len(sparse_layers) model["hit_rate_threshold"] = 0.9 - model["gpucacheper"] = 0.5 + model["gpucacheper"] = self.hugectr_params["gpucacheper"] model["gpucache"] = True model["cache_refresh_percentage_per_iteration"] = 0.2 + model["maxnum_catfeature_query_per_table_per_sample"] = [ + len(x["sparse_embedding_hparam"]["slot_size_array"]) for x in sparse_layers + ] + model["embedding_vecsize_per_table"] = vec_size + model["embedding_table_names"] = [x["top"] for x in sparse_layers] config_dict["models"] = [model] - parameter_server_config_path = str(hugectr_model_path / "ps.json") + parameter_server_config_path = str(node_export_path.parent / "ps.json") with open(parameter_server_config_path, "w") as f: f.write(json.dumps(config_dict)) - self.hugectr_params["config"] = parameter_server_config_path + self.hugectr_params["config"] = network_file + + # These are no longer required from hugectr_backend release 3.7 + self.hugectr_params["cat_feature_num"] = num_cat_columns + self.hugectr_params["des_feature_num"] = data_layer["dense"]["dense_dim"] + self.hugectr_params["embedding_vector_size"] = vec_size[0] + self.hugectr_params["slots"] = num_cat_columns + self.hugectr_params["label_dim"] = data_layer["label"]["label_dim"] + config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size) with open(os.path.join(node_export_path, "config.pbtxt"), "w") as o: @@ -253,13 +261,11 @@ def _hugectr_config(name, hugectr_params, max_batch_size=None): config_hugectr = model_config.ModelParameter(string_value=hugectr_params["config"]) config.parameters["config"].CopyFrom(config_hugectr) - gpucache_val = hugectr_params.get("gpucache", "true") - + gpucache_val = hugectr_params["gpucache"] gpucache = model_config.ModelParameter(string_value=gpucache_val) config.parameters["gpucache"].CopyFrom(gpucache) - gpucacheper_val = str(hugectr_params.get("gpucacheper_val", "0.5")) - + gpucacheper_val = str(hugectr_params["gpucacheper"]) gpucacheper = model_config.ModelParameter(string_value=gpucacheper_val) config.parameters["gpucacheper"].CopyFrom(gpucacheper) @@ -287,8 +293,7 @@ def _hugectr_config(name, hugectr_params, max_batch_size=None): ) config.parameters["embedding_vector_size"].CopyFrom(embedding_vector_size) - embeddingkey_long_type_val = hugectr_params.get("embeddingkey_long_type", "true") - + embeddingkey_long_type_val = hugectr_params["embeddingkey_long_type"] embeddingkey_long_type = model_config.ModelParameter(string_value=embeddingkey_long_type_val) config.parameters["embeddingkey_long_type"].CopyFrom(embeddingkey_long_type) diff --git a/tests/unit/systems/hugectr/test_hugectr.py b/tests/unit/systems/hugectr/test_hugectr.py index 6d16b3e59..837b794cc 100644 --- a/tests/unit/systems/hugectr/test_hugectr.py +++ b/tests/unit/systems/hugectr/test_hugectr.py @@ -14,14 +14,18 @@ # limitations under the License. 
# -import gc import os -import shutil -from os import path import cudf +import numpy as np +import pytest +import nvtabular as nvt +from merlin.dag import ColumnSelector +from merlin.schema import ColumnSchema, Schema +from merlin.systems.dag.ensemble import Ensemble from merlin.systems.dag.ops.hugectr import HugeCTR +from tests.unit.systems.utils.triton import _run_ensemble_on_tritonserver try: import hugectr @@ -30,52 +34,28 @@ except ImportError: hugectr = None -from distutils.spawn import find_executable - -import numpy as np -import pandas as pd +triton = pytest.importorskip("merlin.systems.triton") +grpcclient = pytest.importorskip("tritonclient.grpc") # from common.parsers.benchmark_parsers import create_bench_result # from common.utils import _run_query -from sklearn.model_selection import train_test_split - -import nvtabular as nvt -from merlin.core.utils import download_file -from nvtabular.ops import get_embedding_sizes - -DIR = "/raid/data/movielens/data/" -DATA_DIR = DIR + "data/" -TEMP_DIR = DIR + "temp_hugectr/" -MODEL_DIR = DIR + "models/" -TRAIN_DIR = MODEL_DIR + "test_model/1/" -NETWORK_FILE = TRAIN_DIR + "model.json" -DENSE_FILE = TRAIN_DIR + "_dense_1900.model" -SPARSE_FILES = TRAIN_DIR + "0_sparse_1900.model" -MODEL_NAME = "test_model" - -CATEGORICAL_COLUMNS = ["userId", "movieId", "new_cat1"] -LABEL_COLUMNS = ["rating"] -TEST_N_ROWS = 64 -TRITON_SERVER_PATH = find_executable("tritonserver") -TRITON_DEVICE_ID = "1" - -def _run_model(slot_sizes, total_cardinality): +def _run_model(slot_sizes, source, dense_dim): solver = hugectr.CreateSolver( vvgpu=[[0]], - batchsize=2048, - batchsize_eval=2048, - max_eval_batches=160, + batchsize=10, + batchsize_eval=10, + max_eval_batches=50, i64_input_key=True, use_mixed_precision=False, repeat_dataset=True, ) - + # https://github.com/NVIDIA-Merlin/HugeCTR/blob/9e648f879166fc93931c676a5594718f70178a92/docs/source/api/python_interface.md#datareaderparams reader = hugectr.DataReaderParams( data_reader_type=hugectr.DataReaderType_t.Parquet, - source=[DATA_DIR + "train/_file_list.txt"], - eval_source=DATA_DIR + "valid/_file_list.txt", + source=[os.path.join(source, "_file_list.txt")], + eval_source=os.path.join(source, "_file_list.txt"), check_type=hugectr.Check_t.Non, ) @@ -86,14 +66,13 @@ def _run_model(slot_sizes, total_cardinality): hugectr.Input( label_dim=1, label_name="label", - dense_dim=0, + dense_dim=dense_dim, dense_name="dense", data_reader_sparse_param_array=[ hugectr.DataReaderSparseParam("data1", len(slot_sizes) + 1, True, len(slot_sizes)) ], ) ) - model.add( hugectr.SparseEmbedding( embedding_type=hugectr.Embedding_t.DistributedSlotSparseEmbeddingHash, @@ -106,207 +85,188 @@ def _run_model(slot_sizes, total_cardinality): optimizer=optimizer, ) ) - model.add( - hugectr.DenseLayer( - layer_type=hugectr.Layer_t.Reshape, - bottom_names=["sparse_embedding1"], - top_names=["reshape1"], - leading_dim=48, - ) - ) model.add( hugectr.DenseLayer( layer_type=hugectr.Layer_t.InnerProduct, - bottom_names=["reshape1"], + bottom_names=["dense"], top_names=["fc1"], - num_output=128, + num_output=512, ) ) model.add( hugectr.DenseLayer( - layer_type=hugectr.Layer_t.ReLU, - bottom_names=["fc1"], - top_names=["relu1"], + layer_type=hugectr.Layer_t.Reshape, + bottom_names=["sparse_embedding1"], + top_names=["reshape1"], + leading_dim=48, ) ) model.add( hugectr.DenseLayer( layer_type=hugectr.Layer_t.InnerProduct, - bottom_names=["relu1"], + bottom_names=["reshape1", "fc1"], top_names=["fc2"], - num_output=128, - ) - ) - model.add( - 
hugectr.DenseLayer( - layer_type=hugectr.Layer_t.ReLU, - bottom_names=["fc2"], - top_names=["relu2"], - ) - ) - model.add( - hugectr.DenseLayer( - layer_type=hugectr.Layer_t.InnerProduct, - bottom_names=["relu2"], - top_names=["fc3"], num_output=1, ) ) model.add( hugectr.DenseLayer( layer_type=hugectr.Layer_t.BinaryCrossEntropyLoss, - bottom_names=["fc3", "label"], + bottom_names=["fc2", "label"], top_names=["loss"], ) ) model.compile() model.summary() model.fit(max_iter=20, display=100, eval_interval=200, snapshot=10) - model.graph_to_json(graph_config_file=NETWORK_FILE) return model -def _predict(dense_features, embedding_columns, row_ptrs, config_file, model_name): - inference_params = InferenceParams( - model_name=model_name, - max_batchsize=64, - hit_rate_threshold=0.5, - dense_model_file=DENSE_FILE, - sparse_model_files=[SPARSE_FILES], - device_id=0, - use_gpu_embedding_cache=True, - cache_size_percentage=0.1, - i64_input_key=True, - use_mixed_precision=False, - ) - inference_session = CreateInferenceSession(config_file, inference_params) - output = inference_session.predict(dense_features, embedding_columns, row_ptrs) # , True) - - test_data_path = DATA_DIR + "test/" - embedding_columns_df = pd.DataFrame() - embedding_columns_df["embedding_columns"] = embedding_columns - embedding_columns_df.to_csv(test_data_path + "embedding_columns.csv") - - row_ptrs_df = pd.DataFrame() - row_ptrs_df["row_ptrs"] = row_ptrs - row_ptrs_df.to_csv(test_data_path + "row_ptrs.csv") - - output_df = pd.DataFrame() - output_df["output"] = output - output_df.to_csv(test_data_path + "output.csv") - - -def _convert(data, slot_size_array): - categorical_dim = len(CATEGORICAL_COLUMNS) +def _convert(data, slot_size_array, categorical_columns, labels=None): + labels = labels or [] + dense_columns = list(set(data.columns) - set(categorical_columns + labels)) + categorical_dim = len(categorical_columns) batch_size = data.shape[0] - offset = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1].tolist() - data[CATEGORICAL_COLUMNS] += offset - cat = data[CATEGORICAL_COLUMNS].values.reshape(1, batch_size * categorical_dim).tolist()[0] + shift = np.insert(np.cumsum(slot_size_array), 0, 0)[:-1].tolist() - row_ptrs = list(range(batch_size * categorical_dim + 1)) - dense = [] + # These dtypes are static for HugeCTR + dense = np.array([data[dense_columns].values.flatten().tolist()], dtype="float32") + cat = np.array([(data[categorical_columns] + shift).values.flatten().tolist()], dtype="int64") + rowptr = np.array([list(range(batch_size * categorical_dim + 1))], dtype="int32") - return dense, cat, row_ptrs + return dense, cat, rowptr def test_training(tmpdir): - # Download & Convert data - download_file( - "http://files.grouplens.org/datasets/movielens/ml-25m.zip", - os.path.join(DATA_DIR, "ml-25m.zip"), + cat_dtypes = {"a": int, "b": int, "c": int} + dataset = cudf.datasets.randomdata(1, dtypes={**cat_dtypes, "label": bool}) + dataset["label"] = dataset["label"].astype("int32") + + categorical_columns = list(cat_dtypes.keys()) + + gdf = cudf.DataFrame( + { + "a": np.arange(64), + "b": np.arange(64), + "c": np.arange(64), + "d": np.random.rand(64).tolist(), + "label": [0] * 64, + }, + dtype="int64", ) + gdf["label"] = gdf["label"].astype("float32") + train_dataset = nvt.Dataset(gdf) - ratings = cudf.read_csv(os.path.join(DATA_DIR, "ml-25m", "ratings.csv")) - ratings["new_cat1"] = ratings["userId"] / ratings["movieId"] - ratings["new_cat1"] = ratings["new_cat1"].astype("int64") - ratings.head() - - ratings = 
ratings.drop("timestamp", axis=1) - train, valid = train_test_split(ratings, test_size=0.2, random_state=42) - - train.to_parquet(DATA_DIR + "train.parquet") - valid.to_parquet(DATA_DIR + "valid.parquet") - - del train - del valid - gc.collect() - - # Perform ETL with NVTabular - cat_features = CATEGORICAL_COLUMNS >> nvt.ops.Categorify(cat_cache="device") - ratings = nvt.ColumnSelector(["rating"]) >> nvt.ops.LambdaOp( - lambda col: (col > 3).astype("int8") - ) - output = cat_features + ratings - - workflow = nvt.Workflow(output) - - train_dataset = nvt.Dataset(DATA_DIR + "train.parquet", part_size="100MB") - valid_dataset = nvt.Dataset(DATA_DIR + "valid.parquet", part_size="100MB") - - workflow.fit(train_dataset) + dense_columns = ["d"] dict_dtypes = {} + for col in dense_columns: + dict_dtypes[col] = np.float32 - for col in CATEGORICAL_COLUMNS: + for col in categorical_columns: dict_dtypes[col] = np.int64 - for col in LABEL_COLUMNS: + for col in ["label"]: dict_dtypes[col] = np.float32 - if path.exists(DATA_DIR + "train"): - shutil.rmtree(os.path.join(DATA_DIR, "train")) - if path.exists(DATA_DIR + "valid"): - shutil.rmtree(os.path.join(DATA_DIR, "valid")) + train_path = os.path.join(tmpdir, "train/") + os.mkdir(train_path) - workflow.transform(train_dataset).to_parquet( - output_path=DATA_DIR + "train/", + train_dataset.to_parquet( + output_path=train_path, shuffle=nvt.io.Shuffle.PER_PARTITION, - cats=CATEGORICAL_COLUMNS, - labels=LABEL_COLUMNS, - dtypes=dict_dtypes, - ) - workflow.transform(valid_dataset).to_parquet( - output_path=DATA_DIR + "valid/", - shuffle=False, - cats=CATEGORICAL_COLUMNS, - labels=LABEL_COLUMNS, + cats=categorical_columns, + conts=dense_columns, + labels=["label"], dtypes=dict_dtypes, ) - # Train with HugeCTR - embeddings = get_embedding_sizes(workflow) + embeddings = {"a": (64, 16), "b": (64, 16), "c": (64, 16)} + total_cardinality = 0 slot_sizes = [] - for column in CATEGORICAL_COLUMNS: + + for column in cat_dtypes: slot_sizes.append(embeddings[column][0]) total_cardinality += embeddings[column][0] - test_data_path = DATA_DIR + "test/" - if path.exists(test_data_path): - shutil.rmtree(test_data_path) + # slot sizes = list of caridinalities per column, total is sum of individual + model = _run_model(slot_sizes, train_path, len(dense_columns)) - os.mkdir(test_data_path) + model_op = HugeCTR(model, max_nnz=2, device_list=[0]) - if path.exists(MODEL_DIR): - shutil.rmtree(MODEL_DIR) + model_repository_path = os.path.join(tmpdir, "model_repository") - os.makedirs(TRAIN_DIR) + input_schema = Schema( + [ + ColumnSchema("DES", dtype=np.float32), + ColumnSchema("CATCOLUMN", dtype=np.int64), + ColumnSchema("ROWINDEX", dtype=np.int32), + ] + ) + triton_chain = ColumnSelector(["DES", "CATCOLUMN", "ROWINDEX"]) >> model_op + ens = Ensemble(triton_chain, input_schema) + + os.makedirs(model_repository_path) + + enc_config, node_configs = ens.export(model_repository_path) + + assert enc_config + assert len(node_configs) == 1 + assert node_configs[0].name == "0_hugectr" + + df = train_dataset.to_ddf().compute()[:5] + dense, cats, rowptr = _convert(df, slot_sizes, categorical_columns, labels=["label"]) + + inputs = [ + grpcclient.InferInput("DES", dense.shape, triton.np_to_triton_dtype(dense.dtype)), + grpcclient.InferInput("CATCOLUMN", cats.shape, triton.np_to_triton_dtype(cats.dtype)), + grpcclient.InferInput("ROWINDEX", rowptr.shape, triton.np_to_triton_dtype(rowptr.dtype)), + ] + inputs[0].set_data_from_numpy(dense) + inputs[1].set_data_from_numpy(cats) + 
inputs[2].set_data_from_numpy(rowptr) + + response = _run_ensemble_on_tritonserver( + model_repository_path, + ["OUTPUT0"], + inputs, + "0_hugectr", + backend_config=f"hugectr,ps={tmpdir}/model_repository/ps.json", + ) + assert len(response.as_numpy("OUTPUT0")) == df.shape[0] - sample_data = cudf.read_parquet(DATA_DIR + "valid.parquet", num_rows=TEST_N_ROWS) - sample_data.to_csv(test_data_path + "data.csv") + model_config = node_configs[0].parameters["config"].string_value - sample_data_trans = nvt.workflow.workflow._transform_partition( - sample_data, [workflow.output_node] + hugectr_name = node_configs[0].name + dense_path = f"{tmpdir}/model_repository/{hugectr_name}/1/_dense_0.model" + sparse_files = [f"{tmpdir}/model_repository/{hugectr_name}/1/0_sparse_0.model"] + out_predict = _predict( + dense, cats, rowptr, model_config, hugectr_name, dense_path, sparse_files ) - dense_features, embedding_columns, row_ptrs = _convert(sample_data_trans, slot_sizes) - - model = _run_model(slot_sizes, total_cardinality) + np.testing.assert_array_almost_equal(response.as_numpy("OUTPUT0"), np.array(out_predict)) - model_op = HugeCTR(model) - config = model_op.export(tmpdir, None, None) - assert config is not None - # _predict(dense_features, embedding_columns, row_ptrs, tmpdir, MODEL_NAME) +def _predict( + dense_features, embedding_columns, row_ptrs, config_file, model_name, dense_path, sparse_paths +): + inference_params = InferenceParams( + model_name=model_name, + max_batchsize=64, + hit_rate_threshold=0.5, + dense_model_file=dense_path, + sparse_model_files=sparse_paths, + device_id=0, + use_gpu_embedding_cache=True, + cache_size_percentage=0.2, + i64_input_key=True, + use_mixed_precision=False, + ) + inference_session = CreateInferenceSession(config_file, inference_params) + output = inference_session.predict( + dense_features[0].tolist(), embedding_columns[0].tolist(), row_ptrs[0].tolist() + ) + return output diff --git a/tests/unit/systems/utils/triton.py b/tests/unit/systems/utils/triton.py index 2c84e1ba8..682e9e69b 100644 --- a/tests/unit/systems/utils/triton.py +++ b/tests/unit/systems/utils/triton.py @@ -28,15 +28,15 @@ def _run_ensemble_on_tritonserver( - tmpdir, - output_columns, - df, - model_name, + tmpdir, output_columns, df, model_name, backend_config="tensorflow,version=2" ): - inputs = triton.convert_df_to_triton_input(df.columns, df) + if not isinstance(df, list): + inputs = triton.convert_df_to_triton_input(df.columns, df) + else: + inputs = df outputs = [grpcclient.InferRequestedOutput(col) for col in output_columns] response = None - with run_triton_server(tmpdir) as client: + with run_triton_server(tmpdir, backend_config=backend_config) as client: response = client.infer(model_name, inputs, outputs=outputs) return response From b33549beb2413cad2981af6377e426ca674086ed Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Fri, 1 Jul 2022 08:08:33 -0400 Subject: [PATCH 04/10] add skip for module and add init --- tests/unit/systems/hugectr/__init__.py | 15 +++++++++++++++ tests/unit/systems/hugectr/test_hugectr.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/unit/systems/hugectr/__init__.py b/tests/unit/systems/hugectr/__init__.py index e69de29bb..0b8ff56d3 100644 --- a/tests/unit/systems/hugectr/__init__.py +++ b/tests/unit/systems/hugectr/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2022, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/systems/hugectr/test_hugectr.py b/tests/unit/systems/hugectr/test_hugectr.py index 837b794cc..6aea0cf23 100644 --- a/tests/unit/systems/hugectr/test_hugectr.py +++ b/tests/unit/systems/hugectr/test_hugectr.py @@ -16,7 +16,6 @@ import os -import cudf import numpy as np import pytest @@ -37,6 +36,7 @@ triton = pytest.importorskip("merlin.systems.triton") grpcclient = pytest.importorskip("tritonclient.grpc") +cudf = pytest.importorskip("cudf") # from common.parsers.benchmark_parsers import create_bench_result # from common.utils import _run_query From 1ce70c972db7986c21b4cd412365d2d260110ee8 Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Fri, 1 Jul 2022 08:10:50 -0400 Subject: [PATCH 05/10] remove common folder in tests and remove unneeded lines in test hugectr --- .../unit/common/parsers/benchmark_parsers.py | 178 ------------------ tests/unit/common/parsers/criteo_parsers.py | 139 -------------- tests/unit/common/parsers/rossmann_parsers.py | 77 -------- tests/unit/common/utils.py | 150 --------------- tests/unit/systems/hugectr/test_hugectr.py | 2 - 5 files changed, 546 deletions(-) delete mode 100644 tests/unit/common/parsers/benchmark_parsers.py delete mode 100644 tests/unit/common/parsers/criteo_parsers.py delete mode 100644 tests/unit/common/parsers/rossmann_parsers.py delete mode 100644 tests/unit/common/utils.py diff --git a/tests/unit/common/parsers/benchmark_parsers.py b/tests/unit/common/parsers/benchmark_parsers.py deleted file mode 100644 index 85ffbc62d..000000000 --- a/tests/unit/common/parsers/benchmark_parsers.py +++ /dev/null @@ -1,178 +0,0 @@ -# -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import datetime -import time - -from asvdb import BenchmarkResult - - -class Benchmark: - """ - Main general benchmark parsing class - """ - - def __init__(self, target_id, val=1, split=None): - self.name = f"{target_id}" - self.val = val - self.split = split - - def get_epoch(self, line): - raise NotImplementedError("Must Define logic for parsing metrics per epoch") - - def get_epochs(self, output): - raise NotImplementedError("Must Define logic for parsing output") - - -# Sub classes - - -class StandardBenchmark(Benchmark): - def get_info(self, output): - bench_infos = [] - losses = [] - for line in output: - if "run_time" in line: - bench_infos.append(line) - if "loss" in line: - losses.append(line) - loss_dict = {} - if losses: - loss_dict = {"loss": self.get_loss(losses[-1])} - if bench_infos: - bench_infos = self.get_dl_timing(bench_infos[-1:], optionals=loss_dict) - return bench_infos - - def get_dl_thru( - self, full_time, num_rows, epochs, throughput, optionals=None - ) -> BenchmarkResult: - metrics = [("thru", throughput), ("rows", num_rows), ("epochs", epochs)] - optionals = optionals or {} - for metric_name, metric_value in optionals.items(): - metrics.append((metric_name, metric_value)) - return create_bench_result( - f"{self.name}_dataloader", - metrics, - full_time, - "seconds", - ) - - def get_loss(self, line): - return float(line) - - def loss(self, epoch, loss, l_type="train") -> BenchmarkResult: - return create_bench_result( - f"{self.name}_{l_type}_loss", [("epoch", epoch)], loss, "percent" - ) - - def rmspe(self, epoch, rmspe) -> BenchmarkResult: - return create_bench_result(f"{self.name}_exp_rmspe", [("epoch", epoch)], rmspe, "percent") - - def acc(self, epoch, acc) -> BenchmarkResult: - return create_bench_result(f"{self.name}_exp_rmspe", [("epoch", epoch)], acc, "percent") - - def roc_auc(self, epoch, acc) -> BenchmarkResult: - return create_bench_result(f"{self.name}_exp_rmspe", [("epoch", epoch)], acc, "percent") - - def time(self, epoch, r_time, time_format="%M:%S") -> BenchmarkResult: - if time_format: - x = time.strptime(r_time.split(",")[0], time_format) - r_time = datetime.timedelta( - hours=x.tm_hour, minutes=x.tm_min, seconds=x.tm_sec - ).total_seconds() - return create_bench_result(f"{self.name}_time", [("epoch", epoch)], r_time, "seconds") - - def aps(self, epoch, aps) -> BenchmarkResult: - return create_bench_result(f"{self.name}_Avg_Prec", [("epoch", epoch)], aps, "percent") - - def get_dl_timing(self, output, optionals=None): - timing_res = [] - for line in output: - if line.startswith("run_time"): - run_time, num_rows, epochs, dl_thru = line.split(" - ") - run_time = float(run_time.split(": ")[1]) - num_rows = int(num_rows.split(": ")[1]) - epochs = int(epochs.split(": ")[1]) - dl_thru = float(dl_thru.split(": ")[1]) - bres = self.get_dl_thru( - run_time, num_rows * epochs, epochs, dl_thru, optionals=optionals - ) - timing_res.append(bres) - return timing_res[-1:] - - -class BenchFastAI(StandardBenchmark): - def __init__(self, target_id, val=6, split=None): - super().__init__(f"{target_id}_fastai", val=val, split=split) - - def get_epochs(self, output): - epochs = [] - for line in output: - split_line = line.split(self.split) if self.split else line.split() - if len(split_line) == self.val and is_whole_number(split_line[0]): - # epoch line, detected based on if 1st character is a number - post_evts = self.get_epoch(line) - epochs.append(post_evts) - if "run_time" in line: - epochs.append(self.get_dl_timing(line)) - return epochs[-1:] - - -# 
Utils - - -def is_whole_number(str_to_num): - try: - int(str_to_num) - return True - except ValueError: - return False - - -def is_float(str_to_flt): - try: - float(str_to_flt) - return True - except ValueError: - return False - - -def send_results(db, bench_info, results_list): - # only one entry because entries are split by Bench info - new_results_list = results_list - info_list = list(db.getInfo()) - if len(info_list) > 0: - br_list = db.getResults(filterInfoObjList=[bench_info]) - if br_list: - br_list = br_list[0][1] - results_to_remove = [] - for result in results_list: - if any(br.funcName == result.funcName for br in br_list): - results_to_remove.append(result) - new_results_list = [result for result in results_list if result not in results_to_remove] - # breakpoint() - for results in new_results_list: - if isinstance(results, list): - for result in results: - db.addResult(bench_info, result) - else: - db.addResult(bench_info, results) - - -def create_bench_result(name, arg_tuple_list, result, unit): - return BenchmarkResult( - funcName=name, argNameValuePairs=arg_tuple_list, unit=unit, result=result - ) diff --git a/tests/unit/common/parsers/criteo_parsers.py b/tests/unit/common/parsers/criteo_parsers.py deleted file mode 100644 index a5b23ceb8..000000000 --- a/tests/unit/common/parsers/criteo_parsers.py +++ /dev/null @@ -1,139 +0,0 @@ -# -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import re - -from tests.integration.common.parsers.benchmark_parsers import ( - BenchFastAI, - StandardBenchmark, - create_bench_result, -) - -decimal_regex = "[0-9]+\.?[0-9]*|\.[0-9]+" # noqa pylint: disable=W1401 - - -class CriteoBenchFastAI(BenchFastAI): - def __init__(self, name="CriteoFastAI", val=6, split=None): - self.name = name - self.val = val - self.split = split - - def get_info(self, output): - bench_infos = [] - losses = [] - for line in output: - if "run_time" in line: - bench_infos.append(line) - if "loss" in line and "Train" in line and "Valid" in line: - losses.append(line) - loss_dict = {} - if losses: - for loss in losses: - t_loss, v_loss = self.get_loss(loss) - loss_dict["loss_train"] = t_loss - loss_dict["loss_valid"] = v_loss - if bench_infos: - bench_infos = self.get_dl_timing(bench_infos[-1:], optionals=loss_dict) - return bench_infos - - def get_epoch(self, line): - epoch, t_loss, v_loss, roc, aps, o_time = line.split() - t_loss = self.loss(epoch, float(t_loss)) - v_loss = self.loss(epoch, float(v_loss), l_type="valid") - roc = self.roc_auc(epoch, float(roc)) - aps = self.aps(epoch, float(aps)) - return [t_loss, v_loss, roc, aps, o_time] - - def get_loss(self, line): - epoch, t_loss, v_loss, roc, aps, o_time = line.split() - t_loss = float(t_loss) - v_loss = float(v_loss) - return [t_loss, v_loss] - - -class CriteoBenchHugeCTR(StandardBenchmark): - def __init__(self, name="CriteoHugeCTR"): - self.name = name - - def get_epochs(self, output): - aucs = [] - for line in output: - if "AUC:" in line: - auc_num = float(line.split("AUC:")[-1]) - aucs.append(auc_num) - if "run_time:" in line: - run_time = self.get_runtime(line) - if run_time and aucs: - return self.get_epoch(max(aucs), run_time) - return [] - - def get_runtime(self, line): - split_line = line.split(":") - return float(split_line[1]) - - def get_epoch(self, auc, runtime): - bres_auc = create_bench_result(f"{self.name}_auc", [("time", runtime)], auc, "percent") - return [bres_auc] - - -class CriteoTensorflow(StandardBenchmark): - def __init__(self, name="CriteoTensorFlow"): - self.name = name - - def get_loss(self, line): - loss = line.split("-")[-1] - loss = loss.split(":")[-1] - losses = re.findall(decimal_regex, loss) - losses = losses or [] - return float(losses[-1]) - - -class CriteoTorch(StandardBenchmark): - def __init__(self, name="CriteoTorch"): - self.name = name - - def get_info(self, output): - bench_infos = [] - losses = [] - for line in output: - if "run_time" in line: - bench_infos.append(line) - if "loss" in line and "Train" in line and "Valid" in line: - losses.append(line) - loss_dict = {} - if losses: - for idx, loss in enumerate(losses): - t_loss, v_loss = self.get_loss(loss) - loss_dict["loss_train"] = t_loss - loss_dict["loss_valid"] = v_loss - if bench_infos: - bench_infos = self.get_dl_timing(bench_infos[-1:], optionals=loss_dict) - return bench_infos - - def get_loss(self, line): - # Epoch 00. Train loss: 0.1944. Valid loss: 0.1696. - loss_parse = line.split(". 
") - epoch = loss_parse[0].split(" ")[-1] - train_loss = loss_parse[1].split(":")[-1] - valid_loss = loss_parse[2].split(":")[-1] - - epoch = re.findall(decimal_regex, epoch)[-1] - train_loss = re.findall(decimal_regex, train_loss)[-1] - valid_loss = re.findall(decimal_regex, valid_loss)[-1] - - epoch = int(epoch) - train_loss = float(train_loss) - valid_loss = float(valid_loss) - return [train_loss, valid_loss] diff --git a/tests/unit/common/parsers/rossmann_parsers.py b/tests/unit/common/parsers/rossmann_parsers.py deleted file mode 100644 index 7538065f8..000000000 --- a/tests/unit/common/parsers/rossmann_parsers.py +++ /dev/null @@ -1,77 +0,0 @@ -# -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from tests.integration.common.parsers.benchmark_parsers import BenchFastAI, StandardBenchmark - - -class RossBenchTensorFlow(StandardBenchmark): - def __init__(self, split=" - "): - super().__init__("Rossmann_tf", split=split) - - def get_epoch(self, line, epoch=0): - _, _, t_loss, t_rmspe = line.split(self.split) - t_loss = self.loss(epoch, float(t_loss.split(": ")[1])) - # t_rmspe = self.rmspe(epoch, float(t_rmspe.split(": ")[1])) - return [t_loss, t_rmspe] - - def get_epochs(self, output): - epochs = [] - for idx, line in enumerate(output): - if "Epoch" in line: - epoch = int(line.split()[-1].split("/")[0]) - # output skips line for formatting and remove returns (\x08) - content_line = output[idx + 2].rstrip("\x08") - # epoch line, detected based on if 1st character is a number - post_evts = self.get_epoch(content_line, epoch=epoch) - epochs.append(post_evts) - if "run_time" in line: - epochs.append(self.get_dl_timing(line)) - return epochs[-1:] - - -class RossBenchPytorch(StandardBenchmark): - def __init__(self, split=". 
"): - super().__init__("Rossmann_torch", split=split) - - def get_epoch(self, line): - epoch, t_loss, t_rmspe, v_loss, v_rmspe = line.split(self.split) - epoch = epoch.split()[1] - t_loss = self.loss(epoch, float(t_loss.split(": ")[1])) - v_loss = self.loss(epoch, float(v_loss.split(": ")[1]), l_type="valid") - return [t_loss, v_loss, t_rmspe, v_rmspe] - - def get_epochs(self, output): - epochs = [] - for line in output: - if "Epoch" in line: - # epoch line, detected based on if 1st character is a number - post_evts = self.get_epoch(line) - epochs.append(post_evts) - if "run_time" in line: - epochs.append(self.get_dl_timing(line)) - return epochs[-1:] - - -class RossBenchFastAI(BenchFastAI): - def __init__(self, val=5, split=None): - super().__init__("Rossmann", val=val, split=split) - - def get_epoch(self, line): - epoch, t_loss, v_loss, exp_rmspe, o_time = line.split() - t_loss = self.loss(epoch, float(t_loss)) - v_loss = self.loss(epoch, float(v_loss), l_type="valid") - # exp_rmspe = self.rmspe(epoch, float(exp_rmspe)) - return [t_loss, v_loss, exp_rmspe, o_time] diff --git a/tests/unit/common/utils.py b/tests/unit/common/utils.py deleted file mode 100644 index 9f2271457..000000000 --- a/tests/unit/common/utils.py +++ /dev/null @@ -1,150 +0,0 @@ -# -# Copyright (c) 2021, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import datetime as dt -import itertools -import json -import os -import shutil -import subprocess -import sys - -import cudf -import cupy as cp - -import nvtabular as nvt - - -def _run_notebook( - tmpdir, - notebook_path, - input_path, - output_path, - batch_size=None, - gpu_id=0, - clean_up=True, - transform=None, - params=None, - main_block=-1, -): - params = params or [] - - os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("GPU_TARGET_ID", gpu_id) - - if not os.path.exists(input_path): - os.makedirs(input_path) - if not os.path.exists(output_path): - os.makedirs(output_path) - if batch_size: - os.environ["BATCH_SIZE"] = os.environ.get("BATCH_SIZE", batch_size) - - os.environ["INPUT_DATA_DIR"] = input_path - os.environ["OUTPUT_DATA_DIR"] = output_path - # read in the notebook as JSON, and extract a python script from it - notebook = json.load(open(notebook_path, encoding="utf-8")) - source_cells = [cell["source"] for cell in notebook["cells"] if cell["cell_type"] == "code"] - - lines = [ - transform(line.rstrip()) if transform else line - for line in itertools.chain(*source_cells) - if not (line.startswith("%") or line.startswith("!")) - ] - - # Replace config params - if params: - - def transform_fracs(line): - line = line.replace("device_limit_frac = 0.7", "device_limit_frac = " + str(params[0])) - line = line.replace("device_pool_frac = 0.8", "device_pool_frac = " + str(params[1])) - return line.replace("part_mem_frac = 0.15", "part_mem_frac = " + str(params[2])) - - lines = [transform_fracs(line) for line in lines] - - # Add guarding block and indentation - if main_block >= 0: - lines.insert(main_block, 'if __name__ == "__main__":') - for i in range(main_block + 1, len(lines)): - lines[i] = " " + lines[i] - - # save the script to a file, and run with the current python executable - # we're doing this in a subprocess to avoid some issues using 'exec' - # that were causing a segfault with globals of the exec'ed function going - # out of scope - script_path = os.path.join(tmpdir, "notebook.py") - with open(script_path, "w") as script: - script.write("\n".join(lines)) - output = subprocess.check_output([sys.executable, script_path]) - # save location will default to run location - output = output.decode("utf-8") - _, note_name = os.path.split(notebook_path) - note_name = note_name.split(".")[0] - if output: - with open(f"test_res_{note_name}", "w+") as w_file: - w_file.write(output) - # clear out products - if clean_up: - shutil.rmtree(output_path) - return output - - -def _run_query( - client, - n_rows, - model_name, - workflow_path, - data_path, - actual_output_filename, - output_name, - input_cols_name=None, - backend="tensorflow", -): - - import tritonclient.grpc as grpcclient - from tritonclient.utils import np_to_triton_dtype - - workflow = nvt.Workflow.load(workflow_path) - - if input_cols_name is None: - batch = cudf.read_csv(data_path, nrows=n_rows)[workflow.output_node.input_columns.names] - else: - batch = cudf.read_csv(data_path, nrows=n_rows)[input_cols_name] - - input_dtypes = workflow.input_dtypes - columns = [(col, batch[col]) for col in batch.columns] - - inputs = [] - for i, (name, col) in enumerate(columns): - d = col.values_host.astype(input_dtypes[name]) - d = d.reshape(len(d), 1) - inputs.append(grpcclient.InferInput(name, d.shape, np_to_triton_dtype(input_dtypes[name]))) - inputs[i].set_data_from_numpy(d) - - outputs = [grpcclient.InferRequestedOutput(output_name)] - time_start = dt.datetime.now() - response = client.infer(model_name, inputs, 
request_id="1", outputs=outputs) - run_time = dt.datetime.now() - time_start - - output_key = "output" if backend == "hugectr" else "0" - - output_actual = cudf.read_csv(os.path.expanduser(actual_output_filename), nrows=n_rows) - output_actual = cp.asnumpy(output_actual[output_key].values) - output_predict = response.as_numpy(output_name) - - if backend == "tensorflow": - output_predict = output_predict[:, 0] - - diff = abs(output_actual - output_predict) - return diff, run_time diff --git a/tests/unit/systems/hugectr/test_hugectr.py b/tests/unit/systems/hugectr/test_hugectr.py index 6aea0cf23..6392f0a0d 100644 --- a/tests/unit/systems/hugectr/test_hugectr.py +++ b/tests/unit/systems/hugectr/test_hugectr.py @@ -37,8 +37,6 @@ triton = pytest.importorskip("merlin.systems.triton") grpcclient = pytest.importorskip("tritonclient.grpc") cudf = pytest.importorskip("cudf") -# from common.parsers.benchmark_parsers import create_bench_result -# from common.utils import _run_query def _run_model(slot_sizes, source, dense_dim): From 8cbaf90acd163814e3ab30bf6278638fcb1fc3a5 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 22 Aug 2022 17:02:57 +0100 Subject: [PATCH 06/10] Update formatting based on lint suggestions --- merlin/systems/dag/ops/hugectr.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py index 2d5172613..1f93572cb 100644 --- a/merlin/systems/dag/ops/hugectr.py +++ b/merlin/systems/dag/ops/hugectr.py @@ -151,7 +151,8 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): self.model.graph_to_json(graph_config_file=network_file) self.model.save_params_to_files(str(hugectr_model_path) + "/") - model_json = json.loads(open(network_file, "r").read()) + with open(network_file, "r", encoding="utf-8") as f: + model_json = json.loads(f.read()) dense_pattern = "*_dense_*.model" dense_path = [ os.path.join(hugectr_model_path, path.name) @@ -165,7 +166,7 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): if "opt" not in path.name ] - config_dict = dict() + config_dict = {} config_dict["supportlonglong"] = True data_layer = model_json["layers"][0] @@ -177,7 +178,7 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"]) vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers] - model = dict() + model = {} model["model"] = model_name model["slot_num"] = num_cat_columns model["sparse_files"] = sparse_paths @@ -201,7 +202,7 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): config_dict["models"] = [model] parameter_server_config_path = str(node_export_path.parent / "ps.json") - with open(parameter_server_config_path, "w") as f: + with open(parameter_server_config_path, "w", encoding="utf-8") as f: f.write(json.dumps(config_dict)) self.hugectr_params["config"] = network_file @@ -215,7 +216,7 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size) - with open(os.path.join(node_export_path, "config.pbtxt"), "w") as o: + with open(os.path.join(node_export_path, "config.pbtxt"), "w", encoding="utf-8") as o: text_format.PrintMessage(config, o) return config From 3b3f1d774d0edbf58e6edaee9556f582a0aff3d7 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 22 Aug 2022 17:03:18 +0100 
Subject: [PATCH 07/10] Update _hugectr_config to be more concise --- merlin/systems/dag/ops/hugectr.py | 81 ++++++++++--------------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py index 1f93572cb..ef8b8b3a7 100644 --- a/merlin/systems/dag/ops/hugectr.py +++ b/merlin/systems/dag/ops/hugectr.py @@ -16,6 +16,7 @@ import json import os import pathlib +from typing import Optional import numpy as np import tritonclient.grpc.model_config_pb2 as model_config @@ -222,7 +223,9 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1): return config -def _hugectr_config(name, hugectr_params, max_batch_size=None): +def _hugectr_config( + name: str, parameters: dict, max_batch_size: Optional[int] = None +) -> model_config.ModelConfig: """Create a config for a HugeCTR model. Parameters @@ -239,63 +242,29 @@ def _hugectr_config(name, hugectr_params, max_batch_size=None): config Dictionary representation of hugectr config. """ - config = model_config.ModelConfig(name=name, backend="hugectr", max_batch_size=max_batch_size) - - config.input.append( - model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1]) - ) - - config.input.append( - model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1]) - ) - - config.input.append( - model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1]) + config = model_config.ModelConfig( + name=name, + backend="hugectr", + max_batch_size=max_batch_size, + input=[ + model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1]), + model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1]), + model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1]), + ], + output=[ + model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1]) + ], + instance_group=[model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1)], ) - config.output.append( - model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1]) - ) - - config.instance_group.append(model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1)) - - config_hugectr = model_config.ModelParameter(string_value=hugectr_params["config"]) - config.parameters["config"].CopyFrom(config_hugectr) - - gpucache_val = hugectr_params["gpucache"] - gpucache = model_config.ModelParameter(string_value=gpucache_val) - config.parameters["gpucache"].CopyFrom(gpucache) - - gpucacheper_val = str(hugectr_params["gpucacheper"]) - gpucacheper = model_config.ModelParameter(string_value=gpucacheper_val) - config.parameters["gpucacheper"].CopyFrom(gpucacheper) - - label_dim = model_config.ModelParameter(string_value=str(hugectr_params["label_dim"])) - config.parameters["label_dim"].CopyFrom(label_dim) - - slots = model_config.ModelParameter(string_value=str(hugectr_params["slots"])) - config.parameters["slots"].CopyFrom(slots) - - des_feature_num = model_config.ModelParameter( - string_value=str(hugectr_params["des_feature_num"]) - ) - config.parameters["des_feature_num"].CopyFrom(des_feature_num) - - cat_feature_num = model_config.ModelParameter( - string_value=str(hugectr_params["cat_feature_num"]) - ) - config.parameters["cat_feature_num"].CopyFrom(cat_feature_num) - - max_nnz = model_config.ModelParameter(string_value=str(hugectr_params["max_nnz"])) - config.parameters["max_nnz"].CopyFrom(max_nnz) - - embedding_vector_size = 
model_config.ModelParameter( - string_value=str(hugectr_params["embedding_vector_size"]) - ) - config.parameters["embedding_vector_size"].CopyFrom(embedding_vector_size) + for parameter_key, parameter_value in parameters.items(): + if parameter_value is None: + continue - embeddingkey_long_type_val = hugectr_params["embeddingkey_long_type"] - embeddingkey_long_type = model_config.ModelParameter(string_value=embeddingkey_long_type_val) - config.parameters["embeddingkey_long_type"].CopyFrom(embeddingkey_long_type) + if isinstance(parameter_value, list): + config.parameters[parameter_key].string_value = json.dumps(parameter_value) + elif isinstance(parameter_value, bool): + config.parameters[parameter_key].string_value = str(parameter_value).lower() + config.parameters[parameter_key].string_value = str(parameter_value) return config From c923a2742cd62864ea165e714a3c9dae9326f6eb Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 22 Aug 2022 17:03:37 +0100 Subject: [PATCH 08/10] Add params argument to export method signature --- merlin/systems/dag/ops/hugectr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py index ef8b8b3a7..e4f3d1e4c 100644 --- a/merlin/systems/dag/ops/hugectr.py +++ b/merlin/systems/dag/ops/hugectr.py @@ -120,7 +120,7 @@ def compute_output_schema( """ return Schema([ColumnSchema("OUTPUT0", dtype=np.float32)]) - def export(self, path, input_schema, output_schema, node_id=None, version=1): + def export(self, path, input_schema, output_schema, node_id=None, params=None, version=1): """Create and export the required config files for the hugectr model. Parameters From 4d99847a4d45afb83050acc2c99235edc09ac0eb Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 22 Aug 2022 17:03:56 +0100 Subject: [PATCH 09/10] Add slot_sizes parameter --- merlin/systems/dag/ops/hugectr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py index e4f3d1e4c..97ee0e4d0 100644 --- a/merlin/systems/dag/ops/hugectr.py +++ b/merlin/systems/dag/ops/hugectr.py @@ -176,6 +176,7 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v for layer in model_json["layers"] if layer["type"] == "DistributedSlotSparseEmbeddingHash" ] + full_slots = [x["sparse_embedding_hparam"]["slot_size_array"] for x in sparse_layers] num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"]) vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers] @@ -214,7 +215,7 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v self.hugectr_params["embedding_vector_size"] = vec_size[0] self.hugectr_params["slots"] = num_cat_columns self.hugectr_params["label_dim"] = data_layer["label"]["label_dim"] - + self.hugectr_params["slot_sizes"] = full_slots config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size) with open(os.path.join(node_export_path, "config.pbtxt"), "w", encoding="utf-8") as o: From 027f495e62b6030a2cd712f532280b54c3b54a5a Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Tue, 23 Aug 2022 12:28:49 +0100 Subject: [PATCH 10/10] Extract config to methods and extend with all known params --- merlin/systems/dag/ops/hugectr.py | 369 +++++++++++++++++++++--------- 1 file changed, 257 insertions(+), 112 deletions(-) diff --git a/merlin/systems/dag/ops/hugectr.py b/merlin/systems/dag/ops/hugectr.py index 
97ee0e4d0..160a6caeb 100644
--- a/merlin/systems/dag/ops/hugectr.py
+++ b/merlin/systems/dag/ops/hugectr.py
@@ -16,7 +16,7 @@
 import json
 import os
 import pathlib
-from typing import Optional
+from typing import List, Optional, Union
 
 import numpy as np
 import tritonclient.grpc.model_config_pb2 as model_config
@@ -28,35 +28,122 @@ class HugeCTR(InferenceOperator):
-    """
-    Creates an operator meant to house a HugeCTR model.
-    Allows the model to run as part of a merlin graph operations for inference.
+    """This operator takes a HugeCTR model and packages it correctly for
+    tritonserver to run on the hugectr backend.
     """
 
     def __init__(
         self,
         model,
-        max_batch_size=64,
-        device_list=None,
-        hit_rate_threshold=None,
-        gpucache=None,
-        freeze_sparse=None,
-        gpucacheper=None,
-        max_nnz=2,
-        embeddingkey_long_type=None,
+        *,
+        device_list: Optional[List[int]] = None,
+        max_batch_size: int = 64,
+        gpucache: Optional[bool] = None,
+        hit_rate_threshold: Optional[float] = None,
+        gpucacheper: Optional[float] = None,
+        use_mixed_precision: Optional[bool] = None,
+        scaler: Optional[float] = None,
+        use_algorithm_search: Optional[bool] = None,
+        use_cuda_graph: Optional[bool] = None,
+        num_of_worker_buffer_in_pool: Optional[int] = None,
+        num_of_refresher_buffer_in_pool: Optional[int] = None,
+        cache_refresh_percentage_per_iteration: Optional[float] = None,
+        default_value_for_each_table: float = 0.0,
+        refresh_delay: Optional[float] = None,
+        refresh_interval: Optional[float] = None,
+        freeze_sparse: Optional[bool] = None,
+        max_nnz: Optional[int] = None,
+        embeddingkey_long_type: Optional[bool] = None,
+        supportlonglong: Optional[bool] = None,
+        persistent_db: Optional[dict] = None,
+        volatile_db: Optional[dict] = None,
+        update_source: Optional[dict] = None,
    ):
+        """
+        Parameters
+        ----------
+        model : hugectr.Model, required
+            A HugeCTR model instance.
+        device_list : List[int]
+            The list of devices used to deploy the Hierarchical
+            Parameter Server (HPS). The default is an empty list.
+        max_batch_size : int
+            The maximum batch size to be processed per inference request.
+        gpucache : bool
+            Use this option to enable the GPU embedding cache mechanism.
+        hit_rate_threshold : float
+            Determines the insertion mechanism of the embedding cache
+            and Parameter Server based on the hit rate.
+        gpucacheper : float
+            Determines what percentage of the embedding vectors will
+            be loaded from the embedding table into the GPU embedding
+            cache.
+        use_mixed_precision : bool
+            Determines if mixed precision will be used.
+        scaler : float
+            Scaler for the parameter server model config.
+        use_algorithm_search : bool
+            Determines if algorithm search will be used.
+        use_cuda_graph : bool
+            Determines if CUDA graphs will be used.
+        num_of_worker_buffer_in_pool : int
+            Specifies the number of worker buffers in the pool.
+        num_of_refresher_buffer_in_pool : int
+            Specifies the number of refresher buffers in the pool.
+        cache_refresh_percentage_per_iteration : float
+            The percentage of the cache to refresh each iteration.
+        default_value_for_each_table : float
+            The default value to use for each embedding table.
+        refresh_delay : float
+            Model refresh delay.
+        refresh_interval : float
+            Model refresh interval.
+        freeze_sparse : bool
+            Option to keep sparse tables from being updated.
+            This is useful when using online updates if you wish
+            to disable repeated updates to these embedding tables.
+        max_nnz : int
+            The maximum number of non-zero (NNZ) sparse features per sample.
+        supportlonglong : bool
+            Parameter server config. Specifies if longlong is supported.
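+        embeddingkey_long_type : bool
+            Whether embedding keys use the 64-bit long type, matching the
+            int64 CATCOLUMN input (inferred from the parameter name; the
+            value is passed through to the Triton model config as-is).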
+        persistent_db : dict, optional
+            Configuration for the persistent database. Supports RocksDB.
+        volatile_db : dict, optional
+            Configuration for the volatile database. Allows utilizing
+            Redis cluster deployments, to store and retrieve
+            embeddings in/from the RAM memory available in your
+            cluster.
+        update_source : dict, optional
+            Configuration of the real-time update source for model
+            updates. Supports Apache Kafka.
+        """
         self.model = model
         self.max_batch_size = max_batch_size
         self.device_list = device_list or []
-        embeddingkey_long_type = embeddingkey_long_type or "true"
-        gpucache = gpucache or "true"
-        gpucacheper = gpucacheper or 0.5
-
-        self.hugectr_params = dict(
-            hit_rate_threshold=hit_rate_threshold,
-            gpucache=gpucache,
+        self.hit_rate_threshold = hit_rate_threshold
+        self.gpucache = gpucache
+        self.gpucacheper = gpucacheper
+        self.use_mixed_precision = use_mixed_precision
+        self.scaler = scaler
+        self.use_algorithm_search = use_algorithm_search
+        self.use_cuda_graph = use_cuda_graph
+        self.num_of_worker_buffer_in_pool = num_of_worker_buffer_in_pool
+        self.num_of_refresher_buffer_in_pool = num_of_refresher_buffer_in_pool
+        self.cache_refresh_percentage_per_iteration = cache_refresh_percentage_per_iteration
+        self.default_value_for_each_table = default_value_for_each_table
+        self.refresh_delay = refresh_delay
+        self.refresh_interval = refresh_interval
+        self.supportlonglong = supportlonglong
+        self.persistent_db = persistent_db
+        self.volatile_db = volatile_db
+        self.update_source = update_source
+
+        # These params will be set as parameters in the triton model config.
+        self.model_config_params = dict(
             freeze_sparse=freeze_sparse,
-            gpucacheper=gpucacheper,
             max_nnz=max_nnz,
             embeddingkey_long_type=embeddingkey_long_type,
         )
@@ -70,7 +157,7 @@ def compute_input_schema(
         deps_schema: Schema,
         selector: ColumnSelector,
     ):
-        """_summary_
+        """Return the input schema for this operator.
 
         Parameters
         ----------
@@ -133,6 +220,8 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v
             Schema describing outputs of model
         node_id : int, optional
             The node's position in execution chain, by default None
+        params : dict, optional
+            Dictionary of key-value parameters stored in the exported config, by default None.
         version : int, optional
             The version of the model, by default 1
 
@@ -144,128 +233,184 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v
         node_name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name
         node_export_path = pathlib.Path(path) / node_name
         node_export_path.mkdir(exist_ok=True)
+        model_path = pathlib.Path(node_export_path) / str(version)
+        model_path.mkdir(exist_ok=True)
         model_name = node_name
-        hugectr_model_path = pathlib.Path(node_export_path) / str(version)
-        hugectr_model_path.mkdir(exist_ok=True)
-
-        network_file = os.path.join(hugectr_model_path, f"{model_name}.json")
+        # Write model files
+        network_file = os.path.join(model_path, f"{model_name}.json")
         self.model.graph_to_json(graph_config_file=network_file)
-        self.model.save_params_to_files(str(hugectr_model_path) + "/")
-        with open(network_file, "r", encoding="utf-8") as f:
-            model_json = json.loads(f.read())
+        self.model.save_params_to_files(str(model_path) + "/")
+
+        # Write parameter server configuration
+        # TODO: support multiple models in same ensemble.
+        # parameter server config will need to be centralized and
+        # combine the models from more than one operator.
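+        # The ps.json written below has the shape (sketch):
+        #   {"supportlonglong": ..., "models": [<one dict per HugeCTR model>]}
+        # so centralizing it would amount to merging the "models" lists.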
+        model = self._get_ps_model_config(model_path, model_name)
+        parameter_server_config = {
+            "models": [model],
+            "supportlonglong": self.supportlonglong,
+        }
+        if self.persistent_db:
+            parameter_server_config["persistent_db"] = self.persistent_db
+        if self.volatile_db:
+            parameter_server_config["volatile_db"] = self.volatile_db
+        if self.update_source:
+            parameter_server_config["update_source"] = self.update_source
+        parameter_server_config_path = str(node_export_path.parent / "ps.json")
+        with open(parameter_server_config_path, "w", encoding="utf-8") as f:
+            f.write(json.dumps(parameter_server_config))
+
+        # Write triton model config
+        model_config_params = {**self.model_config_params, "network_file": network_file}
+        config = self._get_model_config(node_name, model_config_params)
+        with open(os.path.join(node_export_path, "config.pbtxt"), "w", encoding="utf-8") as o:
+            text_format.PrintMessage(config, o)
+
+        return config
+
+    def _get_ps_model_config(self, model_path: Union[str, os.PathLike], model_name: str):
+        """Get the HugeCTR model config for the parameter server.
+
+        Parameters
+        ----------
+        model_path : str
+            Directory containing the exported model files.
+        model_name : str
+            The name of the model. A file of the name
+            <model_name>.json is expected to be located in the model
+            path provided.
+        """
+        model_path = pathlib.Path(model_path)
+
+        network_file = model_path / f"{model_name}.json"
+
+        # find paths to dense and sparse models
         dense_pattern = "*_dense_*.model"
         dense_path = [
-            os.path.join(hugectr_model_path, path.name)
-            for path in hugectr_model_path.glob(dense_pattern)
+            str(model_path / path.name)
+            for path in model_path.glob(dense_pattern)
             if "opt" not in path.name
         ][0]
         sparse_pattern = "*_sparse_*.model"
         sparse_paths = [
-            os.path.join(hugectr_model_path, path.name)
-            for path in hugectr_model_path.glob(sparse_pattern)
+            str(model_path / path.name)
+            for path in model_path.glob(sparse_pattern)
            if "opt" not in path.name
         ]
 
-        config_dict = {}
-        config_dict["supportlonglong"] = True
-
+        # find layers in model network file
+        with open(network_file, "r", encoding="utf-8") as f:
+            model_json = json.loads(f.read())
         data_layer = model_json["layers"][0]
         sparse_layers = [
             layer
             for layer in model_json["layers"]
             if layer["type"] == "DistributedSlotSparseEmbeddingHash"
         ]
-        full_slots = [x["sparse_embedding_hparam"]["slot_size_array"] for x in sparse_layers]
-        num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"])
-        vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers]
 
         model = {}
         model["model"] = model_name
-        model["slot_num"] = num_cat_columns
-        model["sparse_files"] = sparse_paths
-        model["dense_file"] = dense_path
-        model["maxnum_des_feature_per_sample"] = data_layer["dense"]["dense_dim"]
         model["network_file"] = network_file
-        model["num_of_worker_buffer_in_pool"] = 4
-        model["num_of_refresher_buffer_in_pool"] = 1
-        model["deployed_device_list"] = self.device_list
         model["max_batch_size"] = self.max_batch_size
-        model["default_value_for_each_table"] = [0.0] * len(sparse_layers)
-        model["hit_rate_threshold"] = 0.9
-        model["gpucacheper"] = self.hugectr_params["gpucacheper"]
-        model["gpucache"] = True
-        model["cache_refresh_percentage_per_iteration"] = 0.2
+        model["dense_file"] = dense_path
+        model["sparse_files"] = sparse_paths
+        model["gpucache"] = self.gpucache
+        model["hit_rate_threshold"] = self.hit_rate_threshold
+        model["gpucacheper"] = self.gpucacheper
+        model["use_mixed_precision"] = self.use_mixed_precision
+        model["scaler"] = self.scaler
+        model["use_algorithm_search"] = self.use_algorithm_search
+        model["use_cuda_graph"] = self.use_cuda_graph
+        model["num_of_worker_buffer_in_pool"] = self.num_of_worker_buffer_in_pool
+        model["num_of_refresher_buffer_in_pool"] = self.num_of_refresher_buffer_in_pool
+        model[
+            "cache_refresh_percentage_per_iteration"
+        ] = self.cache_refresh_percentage_per_iteration
+        model["deployed_device_list"] = self.device_list
+        model["default_value_for_each_table"] = [self.default_value_for_each_table] * len(
+            sparse_layers
+        )
+        # each sample may contain a varying number of numeric (dense)
+        # features. this configures the value of the maximum number
+        # of dense features in each sample, which determines the
+        # pre-allocated memory size on the host and device.
+        model["maxnum_des_feature_per_sample"] = data_layer["dense"]["dense_dim"]
+        model["refresh_delay"] = self.refresh_delay
+        model["refresh_interval"] = self.refresh_interval
+        # This determines the pre-allocated memory size on the host and device.
+        # We assume that for each input sample, there is a maximum
+        # number of embedding keys per sample in each embedding table
+        # that need to be looked up, so the user needs to configure
+        # the [ Maximum(the number of embedding keys that need to be
+        # queried from embedding table 1 in each sample), Maximum(the
+        # number of embedding keys that need to be queried from
+        # embedding table 2 in each sample), ...] in this item.
         model["maxnum_catfeature_query_per_table_per_sample"] = [
             len(x["sparse_embedding_hparam"]["slot_size_array"]) for x in sparse_layers
         ]
-        model["embedding_vecsize_per_table"] = vec_size
+        model["embedding_vecsize_per_table"] = [
+            x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers
+        ]
         model["embedding_table_names"] = [x["top"] for x in sparse_layers]
-        config_dict["models"] = [model]
+        model["label_dim"] = data_layer["label"]["label_dim"]
+        model["slot_num"] = sum(x["slot_num"] for x in data_layer["sparse"])
 
-        parameter_server_config_path = str(node_export_path.parent / "ps.json")
-        with open(parameter_server_config_path, "w", encoding="utf-8") as f:
-            f.write(json.dumps(config_dict))
+        # remove unset (None) values
+        model = {k: v for k, v in model.items() if v is not None}
 
-        self.hugectr_params["config"] = network_file
+        return model
 
-        # These are no longer required from hugectr_backend release 3.7
-        self.hugectr_params["cat_feature_num"] = num_cat_columns
-        self.hugectr_params["des_feature_num"] = data_layer["dense"]["dense_dim"]
-        self.hugectr_params["embedding_vector_size"] = vec_size[0]
-        self.hugectr_params["slots"] = num_cat_columns
-        self.hugectr_params["label_dim"] = data_layer["label"]["label_dim"]
-        self.hugectr_params["slot_sizes"] = full_slots
-        config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size)
+    def _get_model_config(self, name: str, parameters: dict) -> model_config.ModelConfig:
+        """Returns a ModelConfig for a HugeCTR model.
 
-        with open(os.path.join(node_export_path, "config.pbtxt"), "w", encoding="utf-8") as o:
-            text_format.PrintMessage(config, o)
+        Parameters
+        ----------
+        name : string
+            The name of the triton model. This should match the name
+            of the directory where the model is exported.
+        parameters : dict
+            Dictionary holding parameter values for the model configuration.
 
-        return config
+        Returns
+        -------
+        config : model_config.ModelConfig
+            The Triton model configuration for this model.
+        """
+        config = model_config.ModelConfig(
+            name=name,
+            backend="hugectr",
+            max_batch_size=self.max_batch_size,
+            input=[
+                model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1]),
+                model_config.ModelInput(
+                    name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1]
+                ),
+                model_config.ModelInput(
+                    name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1]
+                ),
+            ],
+            output=[
+                model_config.ModelOutput(
+                    name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1]
+                )
+            ],
+            instance_group=[
+                model_config.ModelInstanceGroup(
+                    gpus=self.device_list,
+                    count=len(self.device_list),
+                    kind=model_config.ModelInstanceGroup.Kind.KIND_GPU,
+                )
+            ],
+        )
+        for parameter_key, parameter_value in parameters.items():
+            if parameter_value is None:
+                continue
+            if isinstance(parameter_value, list):
+                config.parameters[parameter_key].string_value = json.dumps(parameter_value)
+            elif isinstance(parameter_value, bool):
+                config.parameters[parameter_key].string_value = str(parameter_value).lower()
+            else:
+                config.parameters[parameter_key].string_value = str(parameter_value)
 
-def _hugectr_config(
-    name: str, parameters: dict, max_batch_size: Optional[int] = None
-) -> model_config.ModelConfig:
-    """Create a config for a HugeCTR model.
-
-    Parameters
-    ----------
-    name : string
-        The name of the hugectr model.
-    hugectr_params : dictionary
-        Dictionary holding parameter values required by hugectr
-    max_batch_size : int, optional
-        The maximum batch size to be processed per batch, by an inference request, by default None
-
-    Returns
-    -------
-    config
-        Dictionary representation of hugectr config.
-    """
-    config = model_config.ModelConfig(
-        name=name,
-        backend="hugectr",
-        max_batch_size=max_batch_size,
-        input=[
-            model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1]),
-            model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1]),
-            model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1]),
-        ],
-        output=[
-            model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1])
-        ],
-        instance_group=[model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1)],
-    )
-
-    for parameter_key, parameter_value in parameters.items():
-        if parameter_value is None:
-            continue
-
-        if isinstance(parameter_value, list):
-            config.parameters[parameter_key].string_value = json.dumps(parameter_value)
-        elif isinstance(parameter_value, bool):
-            config.parameters[parameter_key].string_value = str(parameter_value).lower()
-        config.parameters[parameter_key].string_value = str(parameter_value)
-
-    return config
+        return config
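
Usage sketch (not part of the patches above; distilled from the test in
tests/unit/systems/hugectr/test_hugectr.py): `model` is assumed to be a
trained hugectr.Model, `tmpdir` a scratch directory, and the Ensemble
import path is assumed to be merlin.systems.dag.ensemble.

    import os

    import numpy as np

    from merlin.dag import ColumnSelector
    from merlin.schema import ColumnSchema, Schema
    from merlin.systems.dag.ensemble import Ensemble  # assumed import path
    from merlin.systems.dag.ops.hugectr import HugeCTR

    # Wrap the trained HugeCTR model; device_list selects the GPUs used
    # by the Hierarchical Parameter Server.
    model_op = HugeCTR(model, max_nnz=2, device_list=[0])

    # The hugectr backend always consumes these three flattened inputs.
    input_schema = Schema(
        [
            ColumnSchema("DES", dtype=np.float32),
            ColumnSchema("CATCOLUMN", dtype=np.int64),
            ColumnSchema("ROWINDEX", dtype=np.int32),
        ]
    )

    # Exporting writes config.pbtxt, the network JSON, the dense/sparse
    # model files, and the parameter server config ps.json.
    triton_chain = ColumnSelector(["DES", "CATCOLUMN", "ROWINDEX"]) >> model_op
    ensemble = Ensemble(triton_chain, input_schema)
    repo = os.path.join(tmpdir, "model_repository")
    os.makedirs(repo, exist_ok=True)
    ensemble_config, node_configs = ensemble.export(repo)

Triton is then launched with --backend-config=hugectr,ps=<repo>/ps.json,
matching the backend_config used by _run_ensemble_on_tritonserver in the
test above.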