diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index 8c802b4686..120b91a771 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -64,8 +64,6 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] @@ -187,48 +185,48 @@ "
\n", "1686 rows × 6 columns
\n", + "1718 rows × 6 columns
\n", "" ], "text/plain": [ " session_id item_id-count \\\n", - "0 1 19 \n", - "1 17 13 \n", - "2 34 13 \n", - "4 58 12 \n", - "5 64 12 \n", + "0 6 14 \n", + "1 9 14 \n", + "2 14 13 \n", + "4 39 12 \n", + "5 52 12 \n", "... ... ... \n", - "2110 19074 2 \n", - "2111 19122 2 \n", - "2112 19128 2 \n", - "2113 19134 2 \n", - "2114 19136 2 \n", + "2145 19158 2 \n", + "2146 19165 2 \n", + "2148 19183 2 \n", + "2149 19199 2 \n", + "2150 19221 2 \n", "\n", " item_id-list \\\n", - "0 [27, 26, 7, 46, 13, 2, 4, 237, 10, 35, 46, 35,... \n", - "1 [15, 5, 5, 58, 8, 18, 29, 34, 2, 3, 43, 54, 9] \n", - "2 [17, 12, 9, 21, 29, 6, 23, 6, 5, 176, 12, 26, 1] \n", - "4 [84, 11, 7, 66, 23, 1, 36, 5, 19, 22, 6, 22] \n", - "5 [12, 7, 6, 5, 26, 20, 90, 28, 132, 36, 21, 8] \n", + "0 [7, 11, 73, 6, 31, 5, 19, 63, 52, 1, 28, 19, 2... \n", + "1 [42, 22, 30, 26, 19, 9, 53, 5, 51, 5, 19, 3, 2... \n", + "2 [7, 60, 2, 7, 28, 2, 25, 24, 151, 74, 112, 31,... \n", + "4 [67, 1, 16, 31, 21, 9, 14, 3, 8, 22, 23, 50, 0... \n", + "5 [31, 17, 49, 13, 49, 16, 23, 85, 23, 164, 28, ... \n", "... ... \n", - "2110 [10, 16] \n", - "2111 [37, 28] \n", - "2112 [18, 15] \n", - "2113 [9, 116] \n", - "2114 [6, 9] \n", + "2145 [34, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2146 [1, 60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2148 [23, 29, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", + "2149 [52, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2150 [3, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "\n", " category-list \\\n", - "0 [5, 5, 2, 8, 2, 1, 2, 45, 3, 6, 8, 6, 1, 16, 3... \n", - "1 [3, 1, 1, 10, 1, 4, 6, 6, 1, 1, 8, 10, 1] \n", - "2 [4, 3, 1, 4, 6, 1, 4, 1, 1, 29, 3, 5, 2] \n", - "4 [15, 3, 2, 12, 4, 2, 7, 1, 4, 4, 1, 4] \n", - "5 [3, 2, 1, 1, 5, 2, 16, 5, 23, 7, 4, 1] \n", + "0 [2, 5, 19, 2, 9, 1, 4, 17, 13, 1, 7, 4, 8, 5, ... \n", + "1 [12, 6, 9, 7, 4, 2, 15, 1, 13, 1, 4, 1, 8, 3, ... \n", + "2 [2, 16, 1, 2, 7, 1, 8, 8, 40, 24, 29, 9, 17, 0... \n", + "4 [17, 1, 3, 9, 6, 2, 3, 1, 4, 6, 8, 14, 0, 0, 0... \n", + "5 [9, 3, 13, 5, 13, 3, 8, 23, 8, 51, 7, 2, 0, 0,... \n", "... ... \n", - "2110 [3, 3] \n", - "2111 [7, 5] \n", - "2112 [4, 3] \n", - "2113 [1, 20] \n", - "2114 [1, 1] \n", + "2145 [9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "2146 [1, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2148 [8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "2149 [13, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2150 [1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "\n", " age_days-list \\\n", - "0 [0.97853184, 0.4591664, 0.083990775, 0.7000025... \n", - "1 [0.76496226, 0.85960853, 0.13536207, 0.3988903... \n", - "2 [0.42529476, 0.66954064, 0.46188155, 0.2200255... \n", - "4 [0.7655469, 0.4924979, 0.9192873, 0.6521773, 0... \n", - "5 [0.86268437, 0.11732827, 0.31621945, 0.0408642... \n", + "0 [0.84568787, 0.038363576, 0.7171949, 0.0886422... \n", + "1 [0.4074032, 0.7792388, 0.49303588, 0.027537243... \n", + "2 [0.9137222, 0.77429664, 0.4397028, 0.41606435,... \n", + "4 [0.7679332, 0.7644972, 0.8533882, 0.67827713, ... \n", + "5 [0.32460424, 0.9527502, 0.77985513, 0.91916, 0... \n", "... ... \n", - "2110 [0.9952336, 0.018463716] \n", - "2111 [0.26565734, 0.3376144] \n", - "2112 [0.65739745, 0.46439078] \n", - "2113 [0.45008472, 0.36275008] \n", - "2114 [0.61278456, 0.64234763] \n", + "2145 [0.44386843, 0.17579898, 0.0, 0.0, 0.0, 0.0, 0... \n", + "2146 [0.45839304, 0.15023704, 0.0, 0.0, 0.0, 0.0, 0... \n", + "2148 [0.7376038, 0.7187783, 0.0, 0.0, 0.0, 0.0, 0.0... \n", + "2149 [0.96259063, 0.8100127, 0.0, 0.0, 0.0, 0.0, 0.... \n", + "2150 [0.9268296, 0.71968925, 0.0, 0.0, 0.0, 0.0, 0.... \n", "\n", " weekday_sin-list \n", - "0 [0.04896013, 0.18139902, 0.5046173, 0.48253214... \n", - "1 [0.3081522, 0.17396946, 0.8448347, 0.8297997, ... \n", - "2 [0.951742, 0.7311401, 0.6795269, 0.5283087, 0.... \n", - "4 [0.060284566, 0.9057582, 0.9853312, 0.27452144... \n", - "5 [0.8027563, 0.7638514, 0.055432655, 0.06549974... \n", + "0 [0.9072822, 0.55461484, 0.2662152, 0.6641106, ... \n", + "1 [0.65899414, 0.42423004, 0.20023833, 0.6077999... \n", + "2 [0.3428851, 0.9583178, 0.07852303, 0.8921527, ... \n", + "4 [0.87136024, 0.92441916, 0.27371496, 0.4557360... \n", + "5 [0.12728073, 0.87657094, 0.7073715, 0.9970732,... \n", "... ... \n", - "2110 [0.3855745, 0.8623388] \n", - "2111 [0.519952, 0.117240556] \n", - "2112 [0.49096248, 0.5064814] \n", - "2113 [0.10166882, 0.8127918] \n", - "2114 [0.46083343, 0.8074532] \n", + "2145 [0.58763367, 0.997146, 0.0, 0.0, 0.0, 0.0, 0.0... \n", + "2146 [0.47192892, 0.6211317, 0.0, 0.0, 0.0, 0.0, 0.... \n", + "2148 [0.4954509, 0.5675057, 0.0, 0.0, 0.0, 0.0, 0.0... \n", + "2149 [0.3484375, 0.10194607, 0.0, 0.0, 0.0, 0.0, 0.... \n", + "2150 [0.8299869, 0.7187812, 0.0, 0.0, 0.0, 0.0, 0.0... \n", "\n", - "[1686 rows x 6 columns]" + "[1718 rows x 6 columns]" ] }, "execution_count": 15, @@ -1009,7 +1007,7 @@ { "data": { "text/plain": [ - "490" + "565" ] }, "execution_count": 16, diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb index 1ac1435603..c46ab0dab2 100644 --- a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb +++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb @@ -70,8 +70,56 @@ "text": [ "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n" + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (NDCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (DCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (AvgPrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (PrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (RecallAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n" ] } ], @@ -201,8 +249,8 @@ " (categorical_module): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(503, 64, padding_idx=0)\n", - " (category-list): Embedding(126, 64, padding_idx=0)\n", + " (item_id-list): Embedding(496, 64, padding_idx=0)\n", + " (category-list): Embedding(179, 64, padding_idx=0)\n", " )\n", " )\n", " )\n", @@ -235,6 +283,7 @@ " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", " (dropout): Dropout(p=0.3, inplace=False)\n", + " (activation_function): GELUActivation()\n", " )\n", " (dropout): Dropout(p=0.3, inplace=False)\n", " )\n", @@ -248,6 +297,7 @@ " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", " (dropout): Dropout(p=0.3, inplace=False)\n", + " (activation_function): GELUActivation()\n", " )\n", " (dropout): Dropout(p=0.3, inplace=False)\n", " )\n", @@ -273,15 +323,15 @@ " (embeddings): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(503, 64, padding_idx=0)\n", - " (category-list): Embedding(126, 64, padding_idx=0)\n", + " (item_id-list): Embedding(496, 64, padding_idx=0)\n", + " (category-list): Embedding(179, 64, padding_idx=0)\n", " )\n", " )\n", - " (item_embedding_table): Embedding(503, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(496, 64, padding_idx=0)\n", " (masking): MaskedLanguageModeling()\n", " (pre): Block(\n", " (module): NextItemPredictionTask(\n", - " (item_embedding_table): Embedding(503, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(496, 64, padding_idx=0)\n", " (log_softmax): LogSoftmax(dim=-1)\n", " )\n", " )\n", @@ -386,13 +436,13 @@ { "data": { "text/plain": [ - "tensor([[27, 26, 7, ..., 32, 14, 0],\n", - " [15, 5, 5, ..., 0, 0, 0],\n", - " [17, 12, 9, ..., 0, 0, 0],\n", + "tensor([[ 7, 11, 73, ..., 0, 0, 0],\n", + " [ 42, 22, 30, ..., 0, 0, 0],\n", + " [ 7, 60, 2, ..., 0, 0, 0],\n", " ...,\n", - " [30, 13, 21, ..., 0, 0, 0],\n", - " [19, 14, 8, ..., 0, 0, 0],\n", - " [11, 27, 16, ..., 0, 0, 0]], device='cuda:0')" + " [ 18, 37, 18, ..., 0, 0, 0],\n", + " [ 12, 19, 33, ..., 0, 0, 0],\n", + " [ 11, 16, 102, ..., 0, 0, 0]], device='cuda:0')" ] }, "execution_count": 10, @@ -445,6 +495,19 @@ { "cell_type": "code", "execution_count": 13, + "id": "c215a81a-dec7-466b-aeb5-1e698f0b021f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for col_name, col_schema in input_schema.column_schemas.items():\n", + " input_schema[col_name] = input_schema[col_name].with_shape((None, sparse_max[col_name]))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "id": "757cd0c5-f581-488b-a8de-b8d1188820d6", "metadata": {}, "outputs": [ @@ -476,38 +539,46 @@ "\n", + " | name | \n", + "tags | \n", + "dtype | \n", + "is_list | \n", + "is_ragged | \n", + "properties.int_domain.min | \n", + "properties.int_domain.max | \n", + "properties.triton_scalar_shape | \n", + "properties.value_count.min | \n", + "properties.value_count.max | \n", + "
---|---|---|---|---|---|---|---|---|---|---|
0 | \n", + "age_days-list | \n", + "(Tags.CONTINUOUS, Tags.LIST) | \n", + "DType(name='float32', element_type=<ElementTyp... | \n", + "True | \n", + "False | \n", + "0 | \n", + "0 | \n", + "[] | \n", + "20 | \n", + "20 | \n", + "
1 | \n", + "weekday_sin-list | \n", + "(Tags.CONTINUOUS, Tags.LIST) | \n", + "DType(name='float32', element_type=<ElementTyp... | \n", + "True | \n", + "False | \n", + "0 | \n", + "0 | \n", + "[] | \n", + "20 | \n", + "20 | \n", + "
2 | \n", + "item_id-list | \n", + "(Tags.ITEM_ID, Tags.ID, Tags.LIST, Tags.CATEGO... | \n", + "DType(name='int64', element_type=<ElementType.... | \n", + "True | \n", + "False | \n", + "0 | \n", + "495 | \n", + "[] | \n", + "20 | \n", + "20 | \n", + "
3 | \n", + "category-list | \n", + "(Tags.LIST, Tags.CATEGORICAL) | \n", + "DType(name='int64', element_type=<ElementType.... | \n", + "True | \n", + "False | \n", + "0 | \n", + "178 | \n", + "[] | \n", + "20 | \n", + "20 | \n", + "