Skip to content

Commit

Permalink
Add slot_sizes parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Aug 23, 2022
1 parent c923a27 commit 4d99847
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion merlin/systems/dag/ops/hugectr.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v
for layer in model_json["layers"]
if layer["type"] == "DistributedSlotSparseEmbeddingHash"
]
full_slots = [x["sparse_embedding_hparam"]["slot_size_array"] for x in sparse_layers]
num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"])
vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers]

Expand Down Expand Up @@ -214,7 +215,7 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v
self.hugectr_params["embedding_vector_size"] = vec_size[0]
self.hugectr_params["slots"] = num_cat_columns
self.hugectr_params["label_dim"] = data_layer["label"]["label_dim"]

self.hugectr_params["slot_sizes"] = full_slots
config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size)

with open(os.path.join(node_export_path, "config.pbtxt"), "w", encoding="utf-8") as o:
Expand Down

0 comments on commit 4d99847

Please sign in to comment.