diff --git a/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb b/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb index 20273f6694..6168597d34 100644 --- a/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb @@ -39,7 +39,7 @@ "\n", "**Launch the docker container**\n", "```\n", - "docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -p 8888:8888 -v :/workspace/data/ nvcr.io/nvidia/merlin/merlin-pytorch:22.XX\n", + "docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -p 8888:8888 -v :/workspace/data/ nvcr.io/nvidia/merlin/merlin-pytorch:23.XX\n", "```\n", "This script will mount your local data folder that includes your data files to `/workspace/data` directory in the merlin-pytorch docker container." ] @@ -210,11 +210,11 @@ "output_type": "stream", "text": [ " session_id timestamp item_id category itemid_ts_first\n", - "0 4993 1396727816 214835285 0 1396332436\n", - "1 4993 1396727863 214530703 0 1396339114\n", - "2 4993 1396727898 214530705 0 1396330224\n", - "3 4993 1396728063 214835713 0 1396327474\n", - "4 4993 1396730097 214512611 0 1396328044\n" + "0 7401 1396439960 214826816 0 1396321828\n", + "1 7402 1396780751 214613743 0 1396329089\n", + "2 7402 1396780780 214827011 0 1396735848\n", + "3 7402 1396780912 214821388 0 1396330458\n", + "4 7402 1396780991 214827011 0 1396735848\n" ] } ], @@ -272,7 +272,7 @@ { "data": { "text/plain": [ - "518" + "0" ] }, "execution_count": 10, @@ -321,7 +321,7 @@ "source": [ "In this cell, we are defining three transformations ops: \n", "\n", - "- 1. Encoding categorical variables using `Categorify()` op. We set `start_index` to 1 so that encoded null values start from `1` instead of `0` because we reserve `0` for padding the sequence features.\n", + "- 1. Encoding categorical variables using `Categorify()` op. Categorify op maps nulls to `1`, OOVs to `2`, automatically. We reserve `0` for padding the sequence features. The encoding of each category starts from 3.\n", "- 2. Deriving temporal features from timestamp and computing their cyclical representation using a custom lambda function. \n", "- 3. Computing the item recency in days using a custom op. Note that item recency is defined as the difference between the first occurrence of the item in dataset and the actual date of item interaction. \n", "\n", @@ -336,7 +336,7 @@ "outputs": [], "source": [ "# Encodes categorical features as contiguous integers\n", - "cat_feats = ColumnSelector(['category', 'item_id']) >> nvt.ops.Categorify(start_index=1)\n", + "cat_feats = ColumnSelector(['category', 'item_id']) >> nvt.ops.Categorify()\n", "\n", "# create time features\n", "session_ts = ColumnSelector(['timestamp'])\n", @@ -494,18 +494,7 @@ "execution_count": 13, "id": "45803886", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "dataset = nvt.Dataset(interactions_merged_df)\n", "workflow = nvt.Workflow(filtered_sessions)\n", @@ -559,7 +548,6 @@ " properties.num_buckets\n", " properties.freq_threshold\n", " properties.max_size\n", - " properties.start_index\n", " properties.cat_path\n", " properties.domain.min\n", " properties.domain.max\n", @@ -589,7 +577,6 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " \n", " \n", " 1\n", @@ -601,12 +588,11 @@ " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 52740.0\n", - " item_id\n", " 52741.0\n", + " item_id\n", + " 52742.0\n", " 512.0\n", " NaN\n", " NaN\n", @@ -614,19 +600,18 @@ " \n", " 2\n", " item_id-list\n", - " (Tags.CATEGORICAL, Tags.ITEM_ID, Tags.ITEM, Ta...\n", + " (Tags.CATEGORICAL, Tags.ITEM, Tags.ID, Tags.LIST)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 52740.0\n", - " item_id\n", " 52741.0\n", + " item_id\n", + " 52742.0\n", " 512.0\n", " 0.0\n", " 20.0\n", @@ -634,8 +619,8 @@ " \n", " 3\n", " et_dayofweek_sin-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", - " DType(name='float32', element_type=<ElementTyp...\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", + " DType(name='float64', element_type=<ElementTyp...\n", " True\n", " True\n", " NaN\n", @@ -647,14 +632,13 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " 0.0\n", " 20.0\n", " \n", " \n", " 4\n", " product_recency_days_log_norm-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " True\n", @@ -667,7 +651,6 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " 0.0\n", " 20.0\n", " \n", @@ -681,12 +664,11 @@ " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", " .//categories/unique.category.parquet\n", " 0.0\n", - " 335.0\n", - " category\n", " 336.0\n", + " category\n", + " 337.0\n", " 42.0\n", " 0.0\n", " 20.0\n", @@ -709,14 +691,13 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " \n", " \n", "\n", "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52740, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52741, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52740, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52741, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 335, 'name': 'category'}, 'embedding_sizes': {'cardinality': 336, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]" + "[{'name': 'session_id', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]" ] }, "execution_count": 14, @@ -792,14 +773,14 @@ "6606149 11255553 2 \n", "\n", " item_id-list \\\n", - "6606147 [604, 878, 742, 90, 4777, 1583, 3446, 8083, 34... \n", - "6606148 [184, 12288] \n", - "6606149 [7299, 1953] \n", + "6606147 [605, 879, 743, 91, 4778, 1584, 3447, 8084, 34... \n", + "6606148 [185, 12289] \n", + "6606149 [7300, 1954] \n", "\n", " et_dayofweek_sin-list \\\n", - "6606147 [-0.43388462, -0.43388462, -0.43388462, -0.433... \n", - "6606148 [-0.43388462, -0.43388462] \n", - "6606149 [-0.781831, -0.781831] \n", + "6606147 [-0.43388454782514785, -0.43388454782514785, -... \n", + "6606148 [-0.43388454782514785, -0.43388454782514785] \n", + "6606149 [-0.7818309228245777, -0.7818309228245777] \n", "\n", " product_recency_days_log_norm-list \\\n", "6606147 [1.5241553, 1.5238751, 1.5239341, 1.5241631, 1... \n", @@ -807,9 +788,9 @@ "6606149 [1.5338266, 1.5355074] \n", "\n", " category-list day_index \n", - "6606147 [3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3] 178 \n", - "6606148 [2, 2] 178 \n", - "6606149 [7, 7] 180 \n" + "6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4] 178 \n", + "6606148 [3, 3] 178 \n", + "6606149 [8, 8] 180 \n" ] } ], @@ -827,7 +808,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.76it/s]\n" + "Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.37it/s]\n" ] } ], @@ -849,7 +830,7 @@ { "data": { "text/plain": [ - "570" + "583" ] }, "execution_count": 19, diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index b0ffd0df46..2542ebe16b 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -59,7 +59,18 @@ "execution_count": 2, "id": "1e8dae24", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", @@ -178,48 +189,48 @@ " \n", " \n", " 0\n", - " 83769\n", - " 18\n", - " 5\n", - " 0.405307\n", - " 0.061487\n", - " 9\n", + " 88348\n", + " 28\n", + " 7\n", + " 0.416052\n", + " 0.116508\n", + " 1\n", " \n", " \n", " 1\n", - " 80130\n", - " 24\n", + " 86615\n", + " 6\n", + " 2\n", + " 0.998783\n", + " 0.539034\n", " 6\n", - " 0.441516\n", - " 0.436325\n", - " 7\n", " \n", " \n", " 2\n", - " 71523\n", - " 61\n", - " 15\n", - " 0.952613\n", - " 0.737854\n", - " 7\n", + " 85161\n", + " 14\n", + " 4\n", + " 0.975656\n", + " 0.246331\n", + " 3\n", " \n", " \n", " 3\n", - " 85847\n", - " 4\n", - " 1\n", - " 0.346014\n", - " 0.438771\n", - " 1\n", + " 75889\n", + " 61\n", + " 16\n", + " 0.329182\n", + " 0.033715\n", + " 9\n", " \n", " \n", " 4\n", - " 81181\n", - " 21\n", - " 5\n", - " 0.302464\n", - " 0.437222\n", - " 5\n", + " 75396\n", + " 29\n", + " 8\n", + " 0.219127\n", + " 0.993250\n", + " 7\n", " \n", " \n", "\n", @@ -227,11 +238,11 @@ ], "text/plain": [ " session_id item_id category age_days weekday_sin day\n", - "0 83769 18 5 0.405307 0.061487 9\n", - "1 80130 24 6 0.441516 0.436325 7\n", - "2 71523 61 15 0.952613 0.737854 7\n", - "3 85847 4 1 0.346014 0.438771 1\n", - "4 81181 21 5 0.302464 0.437222 5" + "0 88348 28 7 0.416052 0.116508 1\n", + "1 86615 6 2 0.998783 0.539034 6\n", + "2 85161 14 4 0.975656 0.246331 3\n", + "3 75889 61 16 0.329182 0.033715 9\n", + "4 75396 29 8 0.219127 0.993250 7" ] }, "execution_count": 6, @@ -256,7 +267,7 @@ "id": "139de226", "metadata": {}, "source": [ - "Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers `(0, ..., |C|)`, where `|C|` is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that the `Categorify` op encodes OOVs or nulls to `0` automatically. In our synthetic dataset we do not have any nulls. On the other hand `0` is also used for padding the sequences in input block, therefore, you can set `start_index=1` arg in the Categorify op if you want the encoded null or OOV values to start from `1` instead of `0` because we reserve `0` for padding the sequence features." + "Deep Learning models require dense input features. Categorical features are sparse, and need to be represented by dense embeddings in the model. To allow for that, categorical features first need to be encoded as contiguous integers `(0, ..., |C|)`, where `|C|` is the feature cardinality (number of unique values), so that their embeddings can be efficiently stored in embedding layers. We will use NVTabular to preprocess the categorical features, so that all categorical columns are encoded as contiguous integers. Note that the `Categorify` op encodes `nulls` to `1`, OOVs to `2` automatically. We preserve `0` for padding. The encoding of other categories starts from `3`. In our synthetic dataset we do not have any nulls. On the other hand `0` is used for padding the sequences in input block. " ] }, { @@ -272,106 +283,7 @@ "execution_count": 7, "id": "a256f195", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "SESSIONS_MAX_LENGTH =20\n", "\n", @@ -475,7 +387,6 @@ " properties.num_buckets\n", " properties.freq_threshold\n", " properties.max_size\n", - " properties.start_index\n", " properties.cat_path\n", " properties.domain.min\n", " properties.domain.max\n", @@ -505,7 +416,6 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " \n", " \n", " 1\n", @@ -525,24 +435,22 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " \n", " \n", " 2\n", " item_id-list\n", - " (Tags.ITEM, Tags.LIST, Tags.ITEM_ID, Tags.ID, ...\n", + " (Tags.CATEGORICAL, Tags.ID, Tags.LIST, Tags.ITEM)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", " NaN\n", " 0.0\n", " 0.0\n", - " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 492.0\n", + " 494.0\n", " item_id\n", - " 493.0\n", + " 495.0\n", " 52.0\n", " 2.0\n", " 16.0\n", @@ -550,19 +458,18 @@ " \n", " 3\n", " category-list\n", - " (Tags.LIST, Tags.CATEGORICAL)\n", + " (Tags.CATEGORICAL, Tags.LIST)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", " NaN\n", " 0.0\n", " 0.0\n", - " 0.0\n", " .//categories/unique.category.parquet\n", " 0.0\n", - " 172.0\n", + " 171.0\n", " category\n", - " 173.0\n", + " 172.0\n", " 29.0\n", " 2.0\n", " 16.0\n", @@ -570,7 +477,7 @@ " \n", " 4\n", " age_days-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " True\n", @@ -583,14 +490,13 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " 2.0\n", " 16.0\n", " \n", " \n", " 5\n", " weekday_sin-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " True\n", @@ -603,7 +509,6 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " 2.0\n", " 16.0\n", " \n", @@ -612,7 +517,7 @@ "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 492, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 493, 'dimension': 52}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 172, 'name': 'category'}, 'embedding_sizes': {'cardinality': 173, 'dimension': 29}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}]" + "[{'name': 'session_id', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 494, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 495, 'dimension': 52}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 171, 'name': 'category'}, 'embedding_sizes': {'cardinality': 172, 'dimension': 29}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}]" ] }, "execution_count": 8, @@ -691,20 +596,20 @@ "name": "stdout", "output_type": "stream", "text": [ - " session_id day-first item_id-list category-list \\\n", - "0 70000 4 [41, 174, 53, 6, 1, 1] [10, 45, 13, 2, 1, 1] \n", - "1 70001 7 [61, 6, 8, 19, 27, 1] [15, 2, 2, 5, 7, 1] \n", - "2 70002 9 [3, 186, 7] [1, 45, 2] \n", + " session_id day-first item_id-list category-list \\\n", + "0 70000 1 [306, 5, 40, 17] [104, 3, 12, 6] \n", + "1 70001 1 [43, 20, 69, 8, 57] [13, 6, 21, 3, 16] \n", + "2 70002 1 [137, 35, 37, 85, 65, 5] [37, 10, 11, 22, 18, 3] \n", "\n", " age_days-list \\\n", - "0 [0.7460891, 0.428004, 0.5986437, 0.06489895, 0... \n", - "1 [0.5229818, 0.7588001, 0.52459955, 0.21908909,... \n", - "2 [0.050518222, 0.123651706, 0.7635223] \n", + "0 [0.044022594, 0.34956282, 0.7326993, 0.09403495] \n", + "1 [0.8072543, 0.28916782, 0.04966254, 0.08417622... \n", + "2 [0.04696693, 0.94499177, 0.2922437, 0.83047426... \n", "\n", " weekday_sin-list \n", - "0 [0.8084707, 0.28762853, 0.6447227, 0.23568074,... \n", - "1 [0.34757933, 0.71422356, 0.5048536, 0.91580784... \n", - "2 [0.0969222, 0.007312222, 0.5712969] \n" + "0 [0.7417527, 0.60325843, 0.07417604, 0.28911334] \n", + "1 [0.7995051, 0.86722755, 0.84298295, 0.15793765... \n", + "2 [0.72519076, 0.92308444, 0.40120387, 0.3821016... \n" ] } ], @@ -722,7 +627,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|██████████| 9/9 [00:00<00:00, 34.13it/s]\n" + "Creating time-based splits: 100%|██████████| 9/9 [00:02<00:00, 4.12it/s]\n" ] } ], @@ -790,69 +695,69 @@ " \n", " \n", " 0\n", - " 70005\n", - " [18, 44, 10, 28]\n", - " [5, 11, 2, 7]\n", - " [0.68638474, 0.5562314, 0.15063263, 0.15271373]\n", - " [0.9033511, 0.61792773, 0.2624225, 0.6792286]\n", + " 70000\n", + " [306, 5, 40, 17]\n", + " [104, 3, 12, 6]\n", + " [0.044022594, 0.34956282, 0.7326993, 0.09403495]\n", + " [0.7417527, 0.60325843, 0.07417604, 0.28911334]\n", " \n", " \n", " 1\n", - " 70021\n", - " [123, 9, 165, 58]\n", - " [32, 3, 40, 15]\n", - " [0.845705, 0.49234316, 0.06037465, 0.49099287]\n", - " [0.9355467, 0.6995572, 0.5586675, 0.7762437]\n", + " 70001\n", + " [43, 20, 69, 8, 57]\n", + " [13, 6, 21, 3, 16]\n", + " [0.8072543, 0.28916782, 0.04966254, 0.08417622...\n", + " [0.7995051, 0.86722755, 0.84298295, 0.15793765...\n", " \n", " \n", " 2\n", - " 70024\n", - " [17, 7, 85, 2]\n", - " [5, 2, 21, 1]\n", - " [0.10014895, 0.8815941, 0.120751254, 0.6055496]\n", - " [0.9650407, 0.31927255, 0.64654016, 0.09534722]\n", + " 70002\n", + " [137, 35, 37, 85, 65, 5]\n", + " [37, 10, 11, 22, 18, 3]\n", + " [0.04696693, 0.94499177, 0.2922437, 0.83047426...\n", + " [0.72519076, 0.92308444, 0.40120387, 0.3821016...\n", " \n", " \n", " 4\n", - " 70056\n", - " [4, 5, 65, 22, 23, 9, 51]\n", - " [1, 3, 16, 6, 6, 3, 13]\n", - " [0.0072424724, 0.47864556, 0.011090885, 0.7643...\n", - " [0.23196527, 0.30617988, 0.308681, 0.4716874, ...\n", + " 70007\n", + " [28, 9, 153, 74, 53, 15, 173]\n", + " [9, 4, 39, 20, 15, 5, 46]\n", + " [0.4730765, 0.69885534, 0.034774363, 0.7225920...\n", + " [0.33613566, 0.660022, 0.72897774, 0.66087157,...\n", " \n", " \n", " 5\n", - " 70070\n", - " [2, 8, 11, 82]\n", - " [1, 2, 4, 20]\n", - " [0.73769826, 0.68819356, 0.34728694, 0.8315548]\n", - " [0.1711698, 0.44255763, 0.6116058, 0.9729463]\n", + " 70021\n", + " [59, 32, 11, 21, 23, 23, 9, 15]\n", + " [17, 10, 7, 7, 8, 8, 4, 5]\n", + " [0.07898139, 0.27463168, 0.1885847, 0.5203435,...\n", + " [0.39734098, 0.74895114, 0.43540764, 0.8372503...\n", " \n", " \n", "\n", "" ], "text/plain": [ - " session_id item_id-list category-list \\\n", - "0 70005 [18, 44, 10, 28] [5, 11, 2, 7] \n", - "1 70021 [123, 9, 165, 58] [32, 3, 40, 15] \n", - "2 70024 [17, 7, 85, 2] [5, 2, 21, 1] \n", - "4 70056 [4, 5, 65, 22, 23, 9, 51] [1, 3, 16, 6, 6, 3, 13] \n", - "5 70070 [2, 8, 11, 82] [1, 2, 4, 20] \n", + " session_id item_id-list category-list \\\n", + "0 70000 [306, 5, 40, 17] [104, 3, 12, 6] \n", + "1 70001 [43, 20, 69, 8, 57] [13, 6, 21, 3, 16] \n", + "2 70002 [137, 35, 37, 85, 65, 5] [37, 10, 11, 22, 18, 3] \n", + "4 70007 [28, 9, 153, 74, 53, 15, 173] [9, 4, 39, 20, 15, 5, 46] \n", + "5 70021 [59, 32, 11, 21, 23, 23, 9, 15] [17, 10, 7, 7, 8, 8, 4, 5] \n", "\n", " age_days-list \\\n", - "0 [0.68638474, 0.5562314, 0.15063263, 0.15271373] \n", - "1 [0.845705, 0.49234316, 0.06037465, 0.49099287] \n", - "2 [0.10014895, 0.8815941, 0.120751254, 0.6055496] \n", - "4 [0.0072424724, 0.47864556, 0.011090885, 0.7643... \n", - "5 [0.73769826, 0.68819356, 0.34728694, 0.8315548] \n", + "0 [0.044022594, 0.34956282, 0.7326993, 0.09403495] \n", + "1 [0.8072543, 0.28916782, 0.04966254, 0.08417622... \n", + "2 [0.04696693, 0.94499177, 0.2922437, 0.83047426... \n", + "4 [0.4730765, 0.69885534, 0.034774363, 0.7225920... \n", + "5 [0.07898139, 0.27463168, 0.1885847, 0.5203435,... \n", "\n", " weekday_sin-list \n", - "0 [0.9033511, 0.61792773, 0.2624225, 0.6792286] \n", - "1 [0.9355467, 0.6995572, 0.5586675, 0.7762437] \n", - "2 [0.9650407, 0.31927255, 0.64654016, 0.09534722] \n", - "4 [0.23196527, 0.30617988, 0.308681, 0.4716874, ... \n", - "5 [0.1711698, 0.44255763, 0.6116058, 0.9729463] " + "0 [0.7417527, 0.60325843, 0.07417604, 0.28911334] \n", + "1 [0.7995051, 0.86722755, 0.84298295, 0.15793765... \n", + "2 [0.72519076, 0.92308444, 0.40120387, 0.3821016... \n", + "4 [0.33613566, 0.660022, 0.72897774, 0.66087157,... \n", + "5 [0.39734098, 0.74895114, 0.43540764, 0.8372503... " ] }, "execution_count": 15, @@ -874,7 +779,7 @@ { "data": { "text/plain": [ - "555" + "512" ] }, "execution_count": 16, @@ -913,7 +818,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb index 7682a5e3af..516d82b58c 100644 --- a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb +++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb @@ -63,7 +63,18 @@ "execution_count": 2, "id": "3ba89970", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n" + ] + } + ], "source": [ "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", @@ -126,16 +137,7 @@ "execution_count": 4, "id": "9d1299fa", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "from merlin.schema import Schema\n", "from merlin.io import Dataset\n", @@ -204,8 +206,8 @@ " (categorical_module): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(493, 64, padding_idx=0)\n", - " (category-list): Embedding(173, 64, padding_idx=0)\n", + " (item_id-list): Embedding(495, 64, padding_idx=0)\n", + " (category-list): Embedding(172, 64, padding_idx=0)\n", " )\n", " )\n", " )\n", @@ -238,7 +240,6 @@ " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", " (dropout): Dropout(p=0.3, inplace=False)\n", - " (activation_function): GELUActivation()\n", " )\n", " (dropout): Dropout(p=0.3, inplace=False)\n", " )\n", @@ -252,7 +253,6 @@ " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", " (dropout): Dropout(p=0.3, inplace=False)\n", - " (activation_function): GELUActivation()\n", " )\n", " (dropout): Dropout(p=0.3, inplace=False)\n", " )\n", @@ -278,15 +278,15 @@ " (embeddings): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(493, 64, padding_idx=0)\n", - " (category-list): Embedding(173, 64, padding_idx=0)\n", + " (item_id-list): Embedding(495, 64, padding_idx=0)\n", + " (category-list): Embedding(172, 64, padding_idx=0)\n", " )\n", " )\n", - " (item_embedding_table): Embedding(493, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(495, 64, padding_idx=0)\n", " (masking): MaskedLanguageModeling()\n", " (pre): Block(\n", " (module): NextItemPredictionTask(\n", - " (item_embedding_table): Embedding(493, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(495, 64, padding_idx=0)\n", " )\n", " )\n", " )\n", @@ -364,208 +364,37 @@ { "data": { "text/plain": [ - "{'weekday_sin-list__values': tensor([9.0335e-01, 6.1793e-01, 2.6242e-01, 6.7923e-01, 9.3555e-01, 6.9956e-01,\n", - " 5.5867e-01, 7.7624e-01, 9.6504e-01, 3.1927e-01, 6.4654e-01, 9.5347e-02,\n", - " 2.3197e-01, 3.0618e-01, 3.0868e-01, 4.7169e-01, 8.9577e-01, 4.9156e-01,\n", - " 5.5645e-01, 1.7117e-01, 4.4256e-01, 6.1161e-01, 9.7295e-01, 5.4066e-01,\n", - " 1.1771e-01, 4.5218e-01, 4.0567e-01, 7.1686e-01, 5.5469e-02, 8.5642e-01,\n", - " 1.9757e-02, 1.5123e-01, 7.6822e-01, 4.7810e-01, 7.3114e-01, 5.4558e-01,\n", - " 5.5872e-01, 9.6191e-01, 6.5765e-04, 7.8162e-01, 7.0617e-01, 9.2865e-01,\n", - " 1.3238e-02, 2.1268e-01, 3.3873e-01, 1.3442e-01, 1.8885e-01, 6.8171e-01,\n", - " 5.7528e-01, 3.3820e-01, 3.1279e-01, 1.0624e-01, 5.7980e-01, 3.6283e-01,\n", - " 6.9245e-01, 4.0857e-01, 1.8025e-01, 3.7041e-01, 7.1286e-02, 4.7412e-03,\n", - " 9.4665e-01, 7.3423e-01, 6.9356e-01, 3.0290e-01, 6.0237e-01, 8.6221e-02,\n", - " 3.0249e-01, 4.8158e-01, 4.5226e-01, 7.5549e-01, 7.2938e-01, 9.8018e-01,\n", - " 4.4025e-01, 2.9200e-01, 4.0991e-02, 7.2172e-01, 8.3618e-01, 2.9709e-03,\n", - " 1.7382e-01, 9.4742e-01, 5.2274e-01, 7.7680e-03, 8.5222e-01, 5.3768e-01,\n", - " 9.1014e-01, 8.3261e-01, 7.1640e-01, 3.5953e-01, 6.8745e-01, 7.3665e-01,\n", - " 3.0701e-01, 8.8899e-01, 1.6659e-01, 9.1659e-01, 4.6888e-01, 9.1468e-03,\n", - " 5.7153e-01, 7.8624e-01, 5.6548e-01, 5.7832e-01, 7.7408e-02, 3.8156e-01,\n", - " 4.4821e-01, 5.0542e-01, 8.2571e-01, 8.4099e-01, 8.3429e-01, 4.5229e-01,\n", - " 7.4253e-01, 1.2133e-01, 3.0753e-01, 4.7609e-01, 9.3145e-01, 3.1146e-01,\n", - " 4.7979e-01, 3.1741e-01, 7.0556e-02, 3.7274e-03, 3.3100e-01, 7.4556e-01,\n", - " 1.9668e-01, 2.7893e-01, 3.7026e-01, 5.2113e-01, 8.1771e-01, 7.9684e-01,\n", - " 4.3612e-01, 8.9390e-01, 4.6388e-01, 9.3233e-02, 3.0521e-01, 2.9624e-01,\n", - " 2.0111e-01, 5.3116e-01, 6.3834e-01, 8.1599e-01, 3.9889e-01, 2.0726e-01,\n", - " 5.0752e-01, 6.3629e-01, 1.1615e-01, 6.0907e-01, 9.3705e-01, 7.6081e-01,\n", - " 2.2344e-01, 3.7341e-01, 2.1352e-01, 5.5792e-01, 2.3884e-01, 5.1241e-01,\n", - " 4.2859e-01, 2.6745e-01, 7.0149e-01, 7.7271e-01, 9.7470e-01, 2.9851e-01,\n", - " 3.6220e-01, 7.1964e-01, 9.8924e-02, 4.2278e-01, 9.2394e-01, 5.4416e-02,\n", - " 8.2690e-01, 9.5137e-01, 2.7069e-01, 4.0910e-01, 6.5487e-01, 3.3828e-01,\n", - " 3.5586e-01, 7.8433e-01, 9.5182e-01, 1.0664e-01, 4.5966e-01, 9.1269e-01,\n", - " 9.0110e-01, 8.8323e-01, 8.0163e-01, 6.1198e-01, 7.1632e-01, 8.0810e-01,\n", - " 3.6022e-01, 1.0125e-01, 6.1275e-02, 1.9713e-01, 6.9611e-01, 1.0539e-01,\n", - " 6.3606e-01, 1.7859e-01, 4.9180e-01, 6.0091e-01, 1.9489e-01, 3.1242e-01,\n", - " 5.7435e-02, 7.6455e-01, 9.1532e-03, 1.4022e-01, 6.7814e-01, 6.4885e-01,\n", - " 2.1707e-01, 8.7882e-02, 7.3319e-01, 2.5464e-01, 6.3461e-01, 9.2501e-01,\n", - " 5.0877e-01, 4.6345e-01, 9.6097e-01, 3.0540e-01, 7.2994e-01, 1.8059e-02,\n", - " 2.0526e-01, 4.7532e-01, 6.2443e-01, 5.3653e-01, 1.8451e-01, 3.4582e-01,\n", - " 3.1663e-02, 2.9339e-01, 4.8247e-01, 5.9224e-01, 8.5249e-01, 4.3782e-01,\n", - " 5.0847e-01, 7.4306e-01, 4.0228e-01, 3.5060e-01, 2.0383e-01, 5.9909e-01,\n", - " 1.9630e-01, 3.0325e-01, 5.8152e-01, 1.9598e-04, 1.5798e-01, 8.6199e-01,\n", - " 6.0801e-01, 1.8470e-01, 2.4106e-01, 3.9562e-01, 4.0336e-01, 5.0500e-01,\n", - " 9.8991e-01, 7.7172e-01, 2.0854e-01, 7.1236e-01, 8.5714e-01, 5.9961e-01,\n", - " 6.2871e-01, 9.5358e-01, 9.8993e-01, 2.7197e-02, 4.0978e-01, 7.4517e-01,\n", - " 4.9870e-01, 2.7043e-01, 2.4044e-01, 5.2013e-01, 4.0094e-01, 4.4082e-01,\n", - " 2.1521e-01, 5.7735e-01, 5.5698e-01, 6.5140e-01, 2.1639e-01, 3.9069e-01,\n", - " 1.7496e-01, 9.7648e-02, 3.0500e-02, 3.1027e-01, 3.1238e-01, 7.6859e-02,\n", - " 3.7755e-01, 4.5391e-01, 3.3556e-01, 2.1391e-03, 5.3818e-02, 2.1444e-01,\n", - " 7.0766e-01, 6.2140e-01, 9.6876e-01, 3.0493e-01, 1.1580e-01, 5.9688e-01,\n", - " 1.0951e-01, 6.7726e-01, 5.5295e-01, 2.8579e-01, 1.2490e-02, 1.3221e-03,\n", - " 5.5287e-01, 4.9942e-01, 2.6729e-01, 4.1500e-02, 8.1352e-01, 4.3543e-02,\n", - " 9.8552e-01, 4.6436e-02, 4.1681e-01, 9.0265e-01, 5.7555e-02, 1.9550e-01,\n", - " 6.5123e-01, 1.3511e-01, 1.9624e-01, 4.9690e-01, 1.2664e-01, 8.2656e-01,\n", - " 2.4217e-01, 7.8323e-01, 1.6209e-02, 4.2552e-01, 4.7377e-01, 1.4360e-01,\n", - " 9.0079e-01, 3.5225e-01, 2.5649e-01, 3.7778e-01, 8.2888e-02, 8.9731e-01,\n", - " 6.6882e-01, 3.0262e-01, 8.1284e-01, 9.9743e-01, 8.0287e-01, 6.2663e-02,\n", - " 2.6941e-01, 3.2377e-01, 8.6484e-02, 1.7002e-01, 6.0658e-01, 1.8581e-01,\n", - " 7.4517e-01, 5.9501e-01, 2.6587e-02, 1.4658e-01, 9.7115e-01, 7.2902e-01,\n", - " 3.9843e-01, 2.5587e-01, 8.4932e-02, 8.9251e-01, 8.3415e-01, 1.6210e-01,\n", - " 8.3666e-01, 7.6193e-01, 1.0099e-01, 5.6911e-01, 1.1346e-01, 6.8467e-01,\n", - " 4.4927e-01, 2.6970e-01, 1.3098e-01, 2.9867e-01, 3.7987e-01, 3.8652e-01,\n", - " 4.5246e-01, 4.4367e-01, 5.2977e-01, 5.2168e-02, 4.8019e-01, 6.2810e-01,\n", - " 2.1713e-01, 4.3246e-01, 8.5420e-01, 4.2350e-01, 1.7373e-01, 8.3713e-01,\n", - " 8.0916e-01, 1.9364e-01, 1.1451e-01, 8.3049e-01, 5.5018e-02, 8.4588e-01,\n", - " 9.8551e-02, 6.9381e-01, 3.8779e-02, 6.7309e-01, 2.2248e-01, 8.5116e-01,\n", - " 3.8011e-01, 8.7532e-01, 5.8754e-01, 9.6269e-01, 2.4285e-01, 3.2505e-01,\n", - " 3.3101e-01, 9.8015e-01, 3.4053e-01, 8.5623e-01, 5.4515e-01, 4.3592e-01,\n", - " 4.7118e-01, 7.5477e-01, 4.2002e-01, 4.7548e-01, 1.7103e-01, 7.6246e-02,\n", - " 4.3196e-01, 8.3998e-01, 5.9510e-01, 2.8226e-01, 4.3411e-01, 8.5313e-01,\n", - " 3.3239e-01, 8.4910e-02, 3.3828e-01, 1.9970e-01, 2.0644e-01, 5.2219e-01,\n", - " 5.7864e-01, 2.2277e-01, 2.7035e-01, 8.8196e-01, 1.9876e-01, 7.1749e-01,\n", - " 6.2589e-01, 5.3938e-01, 9.7966e-01, 3.4293e-01, 2.5501e-01, 7.6974e-01,\n", - " 3.5402e-02, 4.7173e-01], device='cuda:0'),\n", - " 'weekday_sin-list__offsets': tensor([ 0, 4, 8, 12, 19, 23, 28, 30, 36, 45, 51, 57, 63, 65,\n", - " 68, 71, 77, 83, 90, 97, 101, 104, 109, 113, 119, 126, 134, 137,\n", - " 147, 154, 158, 162, 170, 174, 180, 184, 190, 196, 202, 206, 211, 217,\n", - " 220, 225, 234, 237, 242, 249, 255, 259, 266, 271, 273, 277, 280, 284,\n", - " 294, 298, 301, 305, 307, 311, 315, 322, 328, 332, 338, 343, 346, 353,\n", - " 356, 360, 367, 370, 374, 386, 391, 394, 399, 405, 412, 419, 422],\n", - " device='cuda:0', dtype=torch.int32),\n", - " 'age_days-list__values': tensor([0.6864, 0.5562, 0.1506, 0.1527, 0.8457, 0.4923, 0.0604, 0.4910, 0.1001,\n", - " 0.8816, 0.1208, 0.6055, 0.0072, 0.4786, 0.0111, 0.7643, 0.2377, 0.9877,\n", - " 0.5471, 0.7377, 0.6882, 0.3473, 0.8316, 0.3003, 0.0339, 0.3124, 0.2429,\n", - " 0.0916, 0.8580, 0.7283, 0.2579, 0.0203, 0.5521, 0.5429, 0.9092, 0.9706,\n", - " 0.5143, 0.3932, 0.5602, 0.2905, 0.3103, 0.7574, 0.6345, 0.8460, 0.1654,\n", - " 0.6438, 0.8112, 0.8559, 0.7671, 0.7051, 0.0342, 0.8202, 0.6450, 0.6239,\n", - " 0.6535, 0.5372, 0.7945, 0.9880, 0.3430, 0.1261, 0.5409, 0.9111, 0.7708,\n", - " 0.3961, 0.1971, 0.1810, 0.9822, 0.0135, 0.5818, 0.4049, 0.5955, 0.0175,\n", - " 0.0582, 0.6451, 0.5682, 0.1767, 0.2046, 0.0181, 0.9892, 0.2091, 0.8784,\n", - " 0.7990, 0.4760, 0.9837, 0.0483, 0.4145, 0.0501, 0.2644, 0.6876, 0.1772,\n", - " 0.9530, 0.6617, 0.7595, 0.6418, 0.7053, 0.7353, 0.1178, 0.5708, 0.9118,\n", - " 0.3817, 0.8163, 0.5592, 0.0263, 0.7426, 0.7504, 0.1453, 0.1261, 0.2843,\n", - " 0.0916, 0.8025, 0.5057, 0.1798, 0.4978, 0.4059, 0.3467, 0.8321, 0.9569,\n", - " 0.2257, 0.4661, 0.5355, 0.5287, 0.2789, 0.6137, 0.7507, 0.2671, 0.3475,\n", - " 0.6775, 0.2755, 0.2705, 0.6361, 0.7418, 0.7470, 0.6971, 0.3926, 0.6781,\n", - " 0.4782, 0.7376, 0.5079, 0.9950, 0.9877, 0.0786, 0.8750, 0.6036, 0.9578,\n", - " 0.3204, 0.4169, 0.2733, 0.8141, 0.6977, 0.5282, 0.8436, 0.4513, 0.0044,\n", - " 0.8754, 0.3406, 0.2862, 0.2795, 0.6000, 0.5966, 0.5941, 0.0756, 0.7670,\n", - " 0.8045, 0.1671, 0.6139, 0.4253, 0.0875, 0.9646, 0.2904, 0.9187, 0.0631,\n", - " 0.0440, 0.8061, 0.8814, 0.9460, 0.3258, 0.1310, 0.4768, 0.1844, 0.9852,\n", - " 0.3005, 0.5041, 0.6340, 0.8558, 0.5530, 0.5584, 0.8220, 0.2970, 0.6035,\n", - " 0.4829, 0.6203, 0.5038, 0.2977, 0.4476, 0.6181, 0.5745, 0.2515, 0.7483,\n", - " 0.4560, 0.2936, 0.8943, 0.4444, 0.7550, 0.8761, 0.9139, 0.8925, 0.9655,\n", - " 0.2475, 0.5479, 0.3591, 0.2017, 0.7192, 0.5854, 0.0280, 0.3394, 0.0548,\n", - " 0.8865, 0.0885, 0.8830, 0.9059, 0.5144, 0.6704, 0.2530, 0.9066, 0.1493,\n", - " 0.0133, 0.3411, 0.5805, 0.3219, 0.8381, 0.1898, 0.9853, 0.4951, 0.8921,\n", - " 0.4684, 0.5548, 0.0959, 0.2271, 0.8040, 0.8180, 0.2429, 0.4656, 0.0248,\n", - " 0.6433, 0.2792, 0.1939, 0.2916, 0.7245, 0.1623, 0.9434, 0.5893, 0.0988,\n", - " 0.4240, 0.7009, 0.0702, 0.5703, 0.0785, 0.3646, 0.8654, 0.5999, 0.8501,\n", - " 0.6637, 0.6934, 0.7100, 0.4033, 0.7082, 0.8922, 0.2551, 0.8100, 0.6159,\n", - " 0.0352, 0.9296, 0.4019, 0.0560, 0.0140, 0.2568, 0.8935, 0.2220, 0.6011,\n", - " 0.5311, 0.6508, 0.9499, 0.2882, 0.3859, 0.7192, 0.5734, 0.5430, 0.5470,\n", - " 0.6100, 0.3221, 0.8932, 0.1917, 0.2715, 0.9648, 0.3337, 0.3423, 0.1694,\n", - " 0.8997, 0.1235, 0.0754, 0.1066, 0.7731, 0.9779, 0.1149, 0.7129, 0.5492,\n", - " 0.5105, 0.5903, 0.1652, 0.8601, 0.3179, 0.4858, 0.5632, 0.2972, 0.3266,\n", - " 0.3010, 0.9428, 0.3308, 0.4054, 0.9851, 0.8445, 0.0965, 0.9255, 0.6904,\n", - " 0.4425, 0.3705, 0.6532, 0.6187, 0.5300, 0.6050, 0.0195, 0.8343, 0.0675,\n", - " 0.2041, 0.3092, 0.6547, 0.6616, 0.0972, 0.4775, 0.3849, 0.3033, 0.9609,\n", - " 0.9839, 0.4001, 0.6065, 0.2203, 0.5405, 0.0463, 0.4115, 0.6639, 0.6231,\n", - " 0.1286, 0.3898, 0.3187, 0.7574, 0.8418, 0.5627, 0.3724, 0.7124, 0.4849,\n", - " 0.6358, 0.1927, 0.9656, 0.2104, 0.3219, 0.0530, 0.7592, 0.4185, 0.8550,\n", - " 0.4254, 0.9376, 0.0722, 0.4513, 0.4236, 0.9076, 0.9693, 0.6590, 0.0147,\n", - " 0.2602, 0.1906, 0.1603, 0.9987, 0.1635, 0.9510, 0.7504, 0.8241, 0.3808,\n", - " 0.4338, 0.3709, 0.9349, 0.5595, 0.8686, 0.4290, 0.6465, 0.5390, 0.1859,\n", - " 0.6526, 0.4768, 0.5636, 0.2489, 0.4109, 0.9884, 0.7743, 0.2705, 0.2039,\n", - " 0.7362, 0.0805, 0.8489, 0.2946, 0.9948, 0.7444, 0.2934, 0.8453, 0.4200,\n", - " 0.2433, 0.1293, 0.9999, 0.2679, 0.7504, 0.5701, 0.4573, 0.9727],\n", - " device='cuda:0'),\n", - " 'age_days-list__offsets': tensor([ 0, 4, 8, 12, 19, 23, 28, 30, 36, 45, 51, 57, 63, 65,\n", - " 68, 71, 77, 83, 90, 97, 101, 104, 109, 113, 119, 126, 134, 137,\n", - " 147, 154, 158, 162, 170, 174, 180, 184, 190, 196, 202, 206, 211, 217,\n", - " 220, 225, 234, 237, 242, 249, 255, 259, 266, 271, 273, 277, 280, 284,\n", - " 294, 298, 301, 305, 307, 311, 315, 322, 328, 332, 338, 343, 346, 353,\n", - " 356, 360, 367, 370, 374, 386, 391, 394, 399, 405, 412, 419, 422],\n", - " device='cuda:0', dtype=torch.int32),\n", - " 'item_id-list__values': tensor([ 18, 44, 10, 28, 123, 9, 165, 58, 17, 7, 85, 2, 4, 5,\n", - " 65, 22, 23, 9, 51, 2, 8, 11, 82, 5, 2, 111, 43, 11,\n", - " 12, 33, 1, 29, 4, 5, 28, 14, 12, 2, 84, 15, 4, 16,\n", - " 6, 5, 2, 89, 19, 31, 12, 42, 22, 13, 12, 21, 48, 4,\n", - " 10, 9, 60, 55, 36, 39, 7, 11, 19, 15, 19, 28, 48, 7,\n", - " 184, 10, 24, 23, 41, 2, 2, 34, 10, 35, 8, 49, 51, 11,\n", - " 188, 40, 101, 101, 83, 10, 6, 1, 7, 72, 7, 7, 17, 85,\n", - " 24, 2, 40, 20, 2, 2, 35, 54, 17, 22, 38, 15, 28, 37,\n", - " 3, 11, 8, 13, 22, 9, 17, 2, 18, 32, 6, 162, 13, 7,\n", - " 7, 51, 66, 30, 104, 58, 10, 349, 2, 17, 3, 18, 12, 21,\n", - " 40, 193, 11, 23, 51, 8, 8, 20, 7, 65, 23, 8, 42, 67,\n", - " 9, 89, 36, 17, 66, 5, 124, 12, 17, 9, 10, 5, 44, 28,\n", - " 14, 87, 40, 5, 87, 10, 8, 11, 27, 3, 8, 16, 64, 12,\n", - " 15, 187, 10, 19, 7, 63, 27, 66, 11, 98, 9, 30, 7, 22,\n", - " 13, 15, 14, 9, 10, 22, 3, 4, 17, 62, 31, 25, 5, 4,\n", - " 20, 10, 85, 3, 83, 6, 19, 31, 15, 16, 6, 9, 4, 100,\n", - " 14, 36, 191, 16, 8, 155, 33, 21, 8, 18, 15, 119, 1, 22,\n", - " 27, 15, 94, 141, 6, 9, 4, 16, 16, 6, 3, 4, 29, 20,\n", - " 2, 2, 27, 17, 4, 49, 17, 112, 8, 67, 2, 26, 6, 35,\n", - " 1, 44, 36, 76, 35, 14, 51, 47, 18, 1, 16, 5, 12, 55,\n", - " 34, 32, 114, 33, 39, 20, 43, 6, 8, 7, 56, 7, 90, 95,\n", - " 18, 3, 59, 20, 3, 2, 5, 20, 120, 4, 137, 2, 21, 7,\n", - " 33, 7, 7, 35, 51, 39, 20, 75, 27, 1, 24, 24, 15, 61,\n", - " 14, 10, 31, 50, 8, 57, 19, 46, 59, 25, 18, 67, 23, 42,\n", - " 38, 2, 3, 21, 25, 97, 2, 7, 79, 32, 73, 11, 26, 24,\n", - " 81, 2, 4, 29, 4, 7, 5, 46, 67, 65, 26, 26, 10, 23,\n", - " 35, 33, 3, 41, 157, 64, 19, 14, 6, 21, 45, 18, 31, 7,\n", - " 10, 147, 14, 12, 37, 3, 13, 13, 7, 35, 13, 7, 22, 5,\n", - " 82, 10, 36, 5, 55, 20, 8, 148, 62, 16, 1, 40, 35, 227,\n", - " 69, 2, 59, 73, 47, 16, 58, 7, 39, 12, 43, 30, 16, 9,\n", - " 23, 164], device='cuda:0'),\n", - " 'item_id-list__offsets': tensor([ 0, 4, 8, 12, 19, 23, 28, 30, 36, 45, 51, 57, 63, 65,\n", - " 68, 71, 77, 83, 90, 97, 101, 104, 109, 113, 119, 126, 134, 137,\n", - " 147, 154, 158, 162, 170, 174, 180, 184, 190, 196, 202, 206, 211, 217,\n", - " 220, 225, 234, 237, 242, 249, 255, 259, 266, 271, 273, 277, 280, 284,\n", - " 294, 298, 301, 305, 307, 311, 315, 322, 328, 332, 338, 343, 346, 353,\n", - " 356, 360, 367, 370, 374, 386, 391, 394, 399, 405, 412, 419, 422],\n", - " device='cuda:0', dtype=torch.int32),\n", - " 'category-list__values': tensor([ 5, 11, 2, 7, 32, 3, 40, 15, 5, 2, 21, 1, 1, 3, 16, 6, 6, 3,\n", - " 13, 1, 2, 4, 20, 3, 1, 28, 11, 4, 4, 3, 1, 8, 1, 3, 7, 3,\n", - " 4, 1, 21, 4, 1, 5, 2, 3, 1, 22, 5, 8, 4, 11, 6, 4, 4, 3,\n", - " 12, 1, 2, 3, 15, 14, 9, 10, 2, 4, 5, 4, 5, 7, 12, 2, 52, 2,\n", - " 6, 6, 10, 1, 1, 9, 2, 9, 2, 12, 13, 4, 53, 10, 26, 26, 21, 2,\n", - " 2, 1, 2, 17, 2, 2, 5, 21, 6, 1, 10, 6, 1, 1, 9, 14, 5, 6,\n", - " 10, 4, 7, 9, 1, 4, 2, 4, 6, 3, 5, 1, 5, 8, 2, 43, 4, 2,\n", - " 2, 13, 17, 8, 27, 15, 2, 84, 1, 5, 1, 5, 4, 3, 10, 45, 4, 6,\n", - " 13, 2, 2, 6, 2, 16, 6, 2, 11, 17, 3, 22, 9, 5, 17, 3, 32, 4,\n", - " 5, 3, 2, 3, 11, 7, 3, 22, 10, 3, 22, 2, 2, 4, 7, 1, 2, 5,\n", - " 16, 4, 4, 49, 2, 5, 2, 16, 7, 17, 4, 24, 3, 8, 2, 6, 4, 4,\n", - " 3, 3, 2, 6, 1, 1, 5, 16, 8, 7, 3, 1, 6, 2, 21, 1, 21, 2,\n", - " 5, 8, 4, 5, 2, 3, 1, 26, 3, 9, 44, 5, 2, 41, 3, 3, 2, 5,\n", - " 4, 30, 1, 6, 7, 4, 23, 36, 2, 3, 1, 5, 5, 2, 1, 1, 8, 6,\n", - " 1, 1, 7, 5, 1, 12, 5, 29, 2, 17, 1, 7, 2, 9, 1, 11, 9, 20,\n", - " 9, 3, 13, 12, 5, 1, 5, 3, 4, 14, 9, 8, 27, 3, 10, 6, 11, 2,\n", - " 2, 2, 14, 2, 23, 23, 5, 1, 15, 6, 1, 1, 3, 6, 29, 1, 31, 1,\n", - " 3, 2, 3, 2, 2, 9, 13, 10, 6, 19, 7, 1, 6, 6, 4, 15, 3, 2,\n", - " 8, 13, 2, 14, 5, 12, 15, 7, 5, 17, 6, 11, 10, 1, 1, 3, 7, 24,\n", - " 1, 2, 19, 8, 18, 4, 7, 6, 21, 1, 1, 8, 1, 2, 3, 12, 17, 16,\n", - " 7, 7, 2, 6, 9, 3, 1, 10, 40, 16, 5, 3, 2, 3, 11, 5, 8, 2,\n", - " 2, 37, 3, 4, 9, 1, 4, 4, 2, 9, 4, 2, 6, 3, 20, 2, 9, 3,\n", - " 14, 6, 2, 39, 16, 5, 1, 10, 9, 54, 18, 1, 15, 18, 12, 5, 15, 2,\n", - " 10, 4, 11, 8, 5, 3, 6, 38], device='cuda:0'),\n", - " 'category-list__offsets': tensor([ 0, 4, 8, 12, 19, 23, 28, 30, 36, 45, 51, 57, 63, 65,\n", - " 68, 71, 77, 83, 90, 97, 101, 104, 109, 113, 119, 126, 134, 137,\n", - " 147, 154, 158, 162, 170, 174, 180, 184, 190, 196, 202, 206, 211, 217,\n", - " 220, 225, 234, 237, 242, 249, 255, 259, 266, 271, 273, 277, 280, 284,\n", - " 294, 298, 301, 305, 307, 311, 315, 322, 328, 332, 338, 343, 346, 353,\n", - " 356, 360, 367, 370, 374, 386, 391, 394, 399, 405, 412, 419, 422],\n", - " device='cuda:0', dtype=torch.int32)}" + "tensor([306, 5, 40, 17, 43, 20, 69, 8, 57, 137, 35, 37, 85, 65,\n", + " 5, 28, 9, 153, 74, 53, 15, 173, 59, 32, 11, 21, 23, 23,\n", + " 9, 15, 12, 69, 37, 16, 6, 22, 39, 20, 22, 95, 40, 7,\n", + " 25, 32, 17, 8, 26, 32, 33, 18, 12, 10, 41, 14, 28, 56,\n", + " 30, 21, 16, 42, 13, 83, 65, 46, 105, 38, 11, 3, 3, 14,\n", + " 9, 36, 116, 15, 15, 23, 8, 16, 68, 151, 60, 18, 48, 19,\n", + " 16, 4, 37, 246, 169, 21, 16, 116, 27, 4, 19, 76, 6, 31,\n", + " 153, 38, 35, 11, 38, 3, 73, 38, 74, 6, 7, 12, 18, 10,\n", + " 54, 11, 29, 5, 24, 11, 20, 3, 17, 42, 26, 24, 30, 26,\n", + " 62, 89, 12, 38, 18, 3, 10, 18, 15, 131, 19, 6, 51, 60,\n", + " 10, 3, 14, 22, 21, 39, 44, 221, 88, 14, 16, 80, 5, 16,\n", + " 21, 81, 27, 8, 20, 49, 32, 83, 49, 19, 3, 17, 8, 10,\n", + " 29, 62, 94, 38, 15, 11, 12, 16, 10, 31, 7, 53, 3, 42,\n", + " 38, 25, 5, 62, 20, 73, 48, 6, 12, 19, 15, 38, 30, 9,\n", + " 82, 31, 49, 64, 22, 38, 10, 56, 11, 13, 3, 14, 39, 18,\n", + " 47, 65, 18, 15, 9, 74, 3, 50, 37, 22, 66, 47, 23, 17,\n", + " 8, 21, 35, 7, 12, 16, 21, 26, 31, 13, 20, 9, 193, 49,\n", + " 9, 62, 51, 45, 90, 14, 47, 9, 73, 16, 3, 62, 24, 82,\n", + " 7, 14, 37, 29, 26, 42, 6, 90, 3, 10, 33, 7, 7, 10,\n", + " 31, 10, 12, 21, 55, 25, 21, 3, 20, 24, 25, 4, 3, 52,\n", + " 5, 5, 10, 12, 37, 162, 31, 5, 119, 5, 24, 65, 4, 10,\n", + " 46, 86, 5, 58, 15, 48, 66, 14, 23, 12, 13, 6, 48, 8,\n", + " 22, 95, 5, 42, 86, 108, 26, 7, 80, 54, 63, 12, 147, 177,\n", + " 17, 18, 24, 15, 40, 5, 40, 7, 6, 63, 4, 18, 123, 33,\n", + " 36, 25, 40, 18, 16, 10, 18, 26, 21, 59, 44, 12, 28, 30,\n", + " 134, 7, 21, 8, 7, 32, 41, 60, 52, 25, 36, 6, 45, 39,\n", + " 16, 20, 95, 8, 56, 53, 48, 17, 14, 3, 46, 35, 17, 12,\n", + " 30, 8, 5, 54, 75, 96, 4, 43, 8, 61, 4, 8, 34, 30,\n", + " 34, 49, 29, 92, 6, 28, 26, 22, 46, 20, 11, 14, 13, 75,\n", + " 22, 21, 17, 166, 4, 87, 5, 11, 37, 26, 23],\n", + " device='cuda:0')" ] }, "execution_count": 9, @@ -574,7 +403,7 @@ } ], "source": [ - "model_input_dict" + "model_input_dict['item_id-list__values']" ] }, { @@ -642,13 +471,12 @@ " is_ragged\n", " properties.value_count.min\n", " properties.value_count.max\n", - " properties.embedding_sizes.dimension\n", - " properties.embedding_sizes.cardinality\n", " properties.num_buckets\n", - " properties.start_index\n", + " properties.freq_threshold\n", " properties.max_size\n", " properties.cat_path\n", - " properties.freq_threshold\n", + " properties.embedding_sizes.cardinality\n", + " properties.embedding_sizes.dimension\n", " properties.domain.min\n", " properties.domain.max\n", " properties.domain.name\n", @@ -658,7 +486,7 @@ " \n", " 0\n", " weekday_sin-list\n", - " (Tags.CONTINUOUS, Tags.LIST)\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " True\n", @@ -673,12 +501,11 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " \n", " \n", " 1\n", " age_days-list\n", - " (Tags.CONTINUOUS, Tags.LIST)\n", + " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " True\n", @@ -693,26 +520,24 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " \n", " \n", " 2\n", " item_id-list\n", - " (Tags.ITEM, Tags.ID, Tags.ITEM_ID, Tags.CATEGO...\n", + " (Tags.CATEGORICAL, Tags.LIST, Tags.ITEM, Tags.ID)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", " 2\n", " 16\n", - " 52.0\n", - " 493.0\n", " NaN\n", " 0.0\n", " 0.0\n", " .//categories/unique.item_id.parquet\n", + " 495.0\n", + " 52.0\n", " 0.0\n", - " 0.0\n", - " 492.0\n", + " 494.0\n", " item_id\n", " \n", " \n", @@ -724,15 +549,14 @@ " True\n", " 2\n", " 16\n", - " 29.0\n", - " 173.0\n", " NaN\n", " 0.0\n", " 0.0\n", " .//categories/unique.category.parquet\n", - " 0.0\n", - " 0.0\n", " 172.0\n", + " 29.0\n", + " 0.0\n", + " 171.0\n", " category\n", " \n", " \n", @@ -740,7 +564,7 @@ "" ], "text/plain": [ - "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'embedding_sizes': {'dimension': 52.0, 'cardinality': 493.0}, 'num_buckets': None, 'start_index': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'freq_threshold': 0.0, 'domain': {'min': 0, 'max': 492, 'name': 'item_id'}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'cat_path': './/categories/unique.category.parquet', 'max_size': 0.0, 'num_buckets': None, 'start_index': 0.0, 'freq_threshold': 0.0, 'embedding_sizes': {'dimension': 29.0, 'cardinality': 173.0}, 'domain': {'min': 0, 'max': 172, 'name': 'category'}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}]" + "[{'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 495.0, 'dimension': 52.0}, 'domain': {'min': 0, 'max': 494, 'name': 'item_id'}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 172.0, 'dimension': 29.0}, 'domain': {'min': 0, 'max': 171, 'name': 'category'}, 'value_count': {'min': 2, 'max': 16}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=16)))), 'is_list': True, 'is_ragged': True}]" ] }, "execution_count": 12, @@ -821,11 +645,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/workspace/merlin/core/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n", - "/workspace/merlin/systems/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'session_id', which is not being used by any downstream operator in the ensemble graph.\n", + "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'session_id', which is not being used by any downstream operator in the ensemble graph.\n", " warnings.warn(\n", - "/workspace/merlin/systems/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'day-first', which is not being used by any downstream operator in the ensemble graph.\n", + "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'day-first', which is not being used by any downstream operator in the ensemble graph.\n", " warnings.warn(\n" ] } @@ -1092,8 +914,8 @@ "output_type": "stream", "text": [ " session_id item_id category age_days weekday_sin day\n", - "0 81119 7 2 0.844459 0.395856 1\n", - "1 70544 8 3 0.994260 0.483951 8\n" + "0 79856 3 2 0.327276 0.080060 2\n", + "1 74117 6 4 0.012172 0.147716 1\n" ] } ], @@ -1135,19 +957,19 @@ { "data": { "text/plain": [ - "{'next-item': array([[-4.0163584, 2.8880801, 2.8579702, ..., -3.2957313, -3.4589174,\n", - " -3.1256526],\n", - " [-4.0165143, 2.887963 , 2.8577764, ..., -3.296179 , -3.4597163,\n", - " -3.1257925],\n", - " [-4.014146 , 2.8889873, 2.8591576, ..., -3.2919312, -3.4524093,\n", - " -3.1247814],\n", + "{'next-item': array([[-3.9399953, -2.632081 , -4.2211075, ..., -3.6699016, -3.673493 ,\n", + " -3.1244578],\n", + " [-3.940445 , -2.6335964, -4.2203593, ..., -3.671566 , -3.6745713,\n", + " -3.1240335],\n", + " [-3.9393594, -2.6300201, -4.222065 , ..., -3.6674871, -3.672068 ,\n", + " -3.1251097],\n", " ...,\n", - " [-4.017418 , 2.8872206, 2.85729 , ..., -3.2978077, -3.462584 ,\n", - " -3.1261778],\n", - " [-4.0187573, 2.8866076, 2.8563938, ..., -3.2999353, -3.466495 ,\n", - " -3.126523 ],\n", - " [-4.0182676, 2.8868668, 2.8567138, ..., -3.2992334, -3.4651473,\n", - " -3.1263661]], dtype=float32)}" + " [-3.9396427, -2.6304667, -4.2218847, ..., -3.6677885, -3.6724825,\n", + " -3.1250875],\n", + " [-3.939829 , -2.6316376, -4.221267 , ..., -3.6693997, -3.6732295,\n", + " -3.1245873],\n", + " [-3.9399223, -2.631995 , -4.2210817, ..., -3.669589 , -3.6734715,\n", + " -3.1244512]], dtype=float32)}" ] }, "execution_count": 21, @@ -1172,7 +994,7 @@ { "data": { "text/plain": [ - "(31, 493)" + "(28, 495)" ] }, "execution_count": 22, @@ -1214,7 +1036,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.8.10" } }, "nbformat": 4, diff --git a/examples/tutorial/02-ETL-with-NVTabular.ipynb b/examples/tutorial/02-ETL-with-NVTabular.ipynb index b12fa95ebd..1cf626965c 100644 --- a/examples/tutorial/02-ETL-with-NVTabular.ipynb +++ b/examples/tutorial/02-ETL-with-NVTabular.ipynb @@ -278,7 +278,7 @@ "tags": [] }, "source": [ - "We see that `'category_code'` and `'brand'` columns have null values, and in the following cell we are going to fill these nulls with via categorify op, and then all categorical columns will be encoded to continuous integers. Note that we add `start_index=1` in the `Categorify op` for the categorical columns, the reason for that we want the encoded null values to start from `1` instead of `0` because we reserve `0` for padding the sequence features." + "We see that `'category_code'` and `'brand'` columns have null values, and in the following cell we are going to fill these nulls with via categorify op, and then all categorical columns will be encoded to continuous integers. Categorify op maps nulls to `1`, OOVs to `2`, automatically. We reserve `0` for padding the sequence features. The encoding of each category starts from `3`." ] }, { @@ -300,7 +300,7 @@ "source": [ "# categorify features \n", "item_id = ['product_id'] >> nvt.ops.TagAsItemID()\n", - "cat_feats = item_id + ['category_code', 'brand', 'user_id', 'category_id', 'event_type'] >> nvt.ops.Categorify(start_index=1)" + "cat_feats = item_id + ['category_code', 'brand', 'user_id', 'category_id', 'event_type'] >> nvt.ops.Categorify()" ] }, { @@ -698,303 +698,9 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
nametagsdtypeis_listis_raggedproperties.num_bucketsproperties.freq_thresholdproperties.max_sizeproperties.start_indexproperties.cat_pathproperties.domain.minproperties.domain.maxproperties.domain.nameproperties.embedding_sizes.cardinalityproperties.embedding_sizes.dimensionproperties.value_count.minproperties.value_count.max
0user_session(Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....FalseFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
1product_id-count(Tags.ITEM_ID, Tags.ITEM, Tags.CATEGORICAL, Ta...DType(name='int32', element_type=<ElementType....FalseFalseNaN0.00.01.0.//categories/unique.product_id.parquet0.0118334.0product_id118335.0512.0NaNNaN
2product_id-list(Tags.ID, Tags.CATEGORICAL, Tags.LIST, Tags.IT...DType(name='int64', element_type=<ElementType....TrueFalseNaN0.00.01.0.//categories/unique.product_id.parquet0.0118334.0product_id118335.0512.020.020.0
3category_code-list(Tags.CATEGORICAL, Tags.LIST)DType(name='int64', element_type=<ElementType....TrueFalseNaN0.00.01.0.//categories/unique.category_code.parquet0.0124.0category_code125.024.020.020.0
4brand-list(Tags.CATEGORICAL, Tags.LIST)DType(name='int64', element_type=<ElementType....TrueFalseNaN0.00.01.0.//categories/unique.brand.parquet0.02640.0brand2641.0132.020.020.0
5category_id-list(Tags.CATEGORICAL, Tags.LIST)DType(name='int64', element_type=<ElementType....TrueFalseNaN0.00.01.0.//categories/unique.category_id.parquet0.0566.0category_id567.056.020.020.0
6et_dayofweek_sin-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float32', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
7et_dayofweek_cos-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float32', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
8price_log_norm-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float32', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
9relative_price_to_avg_categ_id-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float64', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
10product_recency_days_log_norm-list(Tags.LIST, Tags.CONTINUOUS)DType(name='float32', element_type=<ElementTyp...TrueFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN20.020.0
11day_index(Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....FalseFalseNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
\n", - "
" - ], - "text/plain": [ - "[{'name': 'user_session', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'product_id-count', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.product_id.parquet', 'domain': {'min': 0, 'max': 118334, 'name': 'product_id'}, 'embedding_sizes': {'cardinality': 118335, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'product_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.product_id.parquet', 'domain': {'min': 0, 'max': 118334, 'name': 'product_id'}, 'embedding_sizes': {'cardinality': 118335, 'dimension': 512}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category_code-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.category_code.parquet', 'domain': {'min': 0, 'max': 124, 'name': 'category_code'}, 'embedding_sizes': {'cardinality': 125, 'dimension': 24}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'brand-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.brand.parquet', 'domain': {'min': 0, 'max': 2640, 'name': 'brand'}, 'embedding_sizes': {'cardinality': 2641, 'dimension': 132}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category_id-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.category_id.parquet', 'domain': {'min': 0, 'max': 566, 'name': 'category_id'}, 'embedding_sizes': {'cardinality': 567, 'dimension': 56}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'et_dayofweek_cos-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'price_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'relative_price_to_avg_categ_id-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'day_index', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=None)), 'is_list': False, 'is_ragged': False}]" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "workflow.output_schema" ] diff --git a/examples/tutorial/03-Session-based-recsys.ipynb b/examples/tutorial/03-Session-based-recsys.ipynb index 261e47555d..266be9e31f 100644 --- a/examples/tutorial/03-Session-based-recsys.ipynb +++ b/examples/tutorial/03-Session-based-recsys.ipynb @@ -2330,120 +2330,6 @@ " f.write('%s:%s\\n' % (key, value.item()))" ] }, - { - "cell_type": "markdown", - "id": "dfd8172f", - "metadata": {}, - "source": [ - "After model training and evaluation is completed we can save our trained model in the next section. " - ] - }, - { - "cell_type": "markdown", - "id": "e721d830", - "metadata": {}, - "source": [ - "##### Exporting the preprocessing workflow and model for deployment to Triton server" - ] - }, - { - "cell_type": "markdown", - "id": "2aa9eb95", - "metadata": {}, - "source": [ - "Load the preproc workflow that we saved in the ETL notebook." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "3dd3aa6a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import nvtabular as nvt\n", - "\n", - "# define data path about where to get our data\n", - "INPUT_DATA_DIR = os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data/\")\n", - "workflow_path = os.path.join(INPUT_DATA_DIR, 'workflow_etl')\n", - "workflow = nvt.Workflow.load(workflow_path)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6e626132", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'product_id-list': 20,\n", - " 'category_id-list': 20,\n", - " 'brand-list': 20,\n", - " 'product_recency_days_log_norm-list': 20,\n", - " 'et_dayofweek_sin-list': 20,\n", - " 'et_dayofweek_cos-list': 20,\n", - " 'price_log_norm-list': 20,\n", - " 'relative_price_to_avg_categ_id-list': 20,\n", - " 'category_code-list': 20}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# dictionary representing max sequence length for the sequential (list) columns\n", - "sparse_features_max = {\n", - " fname: sequence_length\n", - " for fname in x_cat_names + x_cont_names + ['category_code-list']\n", - "}\n", - "\n", - "sparse_features_max" - ] - }, - { - "cell_type": "markdown", - "id": "120c5740", - "metadata": {}, - "source": [ - "It is time to export the proc workflow and model in the format required by Triton Inference Server, by using the NVTabular’s `export_pytorch_ensemble()` function." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aee05de6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from nvtabular.inference.triton import export_pytorch_ensemble\n", - "export_pytorch_ensemble(\n", - " model,\n", - " workflow,\n", - " sparse_max=sparse_features_max,\n", - " name= \"t4r_pytorch\",\n", - " model_path= os.path.join(INPUT_DATA_DIR, 'models'),\n", - " label_columns =[],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "43ee6473", - "metadata": {}, - "source": [ - "Before we move on to the next notebook, `04-Inference-with-Triton`, let's print out our results.txt file. " - ] - }, { "cell_type": "code", "execution_count": 13, @@ -2504,14 +2390,6 @@ "Congratulations on finishing this notebook. In this tutorial, we have presented Transformers4Rec, an open source library designed to enable RecSys researchers and practitioners to quickly and easily explore the latest developments of the NLP for sequential and session-based recommendation tasks." ] }, - { - "cell_type": "markdown", - "id": "045d54e7", - "metadata": {}, - "source": [ - "Please shut down the kernel before moving on to the next notebook, `04-Inference-with-Triton.ipynb`." - ] - }, { "cell_type": "markdown", "id": "744a2f17", diff --git a/examples/tutorial/04-Inference-with-Triton.ipynb b/examples/tutorial/04-Inference-with-Triton.ipynb deleted file mode 100644 index 69d90aa016..0000000000 --- a/examples/tutorial/04-Inference-with-Triton.ipynb +++ /dev/null @@ -1,688 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "94181761", - "metadata": {}, - "outputs": [], - "source": [ - "# Copyright 2022 NVIDIA Corporation. All Rights Reserved.\n", - "#\n", - "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# http://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License.\n", - "# ==============================================================================\n", - "\n", - "# Each user is responsible for checking the content of datasets and the\n", - "# applicable licenses and determining if suitable for the intended use." - ] - }, - { - "cell_type": "markdown", - "id": "77d36393", - "metadata": {}, - "source": [ - "\n", - "\n", - "# Triton for Recommender Systems" - ] - }, - { - "cell_type": "markdown", - "id": "172f9e15", - "metadata": {}, - "source": [ - "NVIDIA [Triton Inference Server (TIS)](https://github.com/triton-inference-server/server) simplifies the deployment of AI models at scale in production. The Triton Inference Server allows us to deploy and serve our model for inference. It supports a number of different machine learning frameworks such as TensorFlow and PyTorch.\n", - "\n", - "The last step of machine learning (ML)/deep learning (DL) pipeline is to deploy the ETL workflow and saved model to production. In the production setting, we want to transform the input data as done during training (ETL). We need to apply the same mean/std for continuous features and use the same categorical mapping to convert the categories to continuous integer before we use the DL model for a prediction. Therefore, we deploy the NVTabular workflow with the PyTorch model as an ensemble model to Triton Inference. The ensemble model guarantees that the same transformation is applied to the raw inputs." - ] - }, - { - "cell_type": "markdown", - "id": "ae645daa", - "metadata": {}, - "source": [ - "![](_images/torch_triton.png)" - ] - }, - { - "cell_type": "markdown", - "id": "6f85f45d", - "metadata": {}, - "source": [ - "**Objectives:**\n", - "\n", - "Learn how to deploy a model to Triton\n", - "1. Deploy saved NVTabular and PyTorch models to Triton Inference Server\n", - "2. Sent requests for predictions" - ] - }, - { - "cell_type": "markdown", - "id": "43dc14a8", - "metadata": {}, - "source": [ - "## Pull and start Inference docker container" - ] - }, - { - "cell_type": "markdown", - "id": "f22667d0", - "metadata": {}, - "source": [ - "At this point, we start the Triton Inference Server (TIS) and then load the exported ensemble `t4r_pytorch` to the inference server. You can start triton server with the command below. Note that, you need to provide correct path of the models folder.\n", - "\n", - "```\n", - "tritonserver --model-repository= --model-control-mode=explicit\n", - "```\n", - "The model-repository path for our example is `/workspace/models`. The models haven't been loaded, yet. Below, we will request the Triton server to load the saved ensemble model." - ] - }, - { - "cell_type": "markdown", - "id": "07499907", - "metadata": {}, - "source": [ - "## 1. Deploy PyTorch and NVTabular Model to Triton Inference Server" - ] - }, - { - "cell_type": "markdown", - "id": "6b61ed1a", - "metadata": {}, - "source": [ - "Our Triton server has already been launched and is ready to make requests. Remember we already exported the saved PyTorch model in the previous notebook, and generated the config files for Triton Inference Server." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "6645e40e", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Import dependencies\n", - "import os\n", - "from time import time\n", - "\n", - "import numpy as np\n", - "import sys\n", - "import cudf" - ] - }, - { - "cell_type": "markdown", - "id": "72c90e93", - "metadata": {}, - "source": [ - "## 1.2 Review exported files" - ] - }, - { - "cell_type": "markdown", - "id": "6b8b7a4c", - "metadata": {}, - "source": [ - "Triton expects a specific directory structure for our models as the following format:" - ] - }, - { - "cell_type": "markdown", - "id": "d34dcb28", - "metadata": {}, - "source": [ - "```\n", - "/\n", - "[config.pbtxt]\n", - "/\n", - " [model.savedmodel]/\n", - " /\n", - " ...\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "9d7d3156", - "metadata": {}, - "source": [ - "Let's check out our model repository layout. You can install tree library with `apt-get install tree`, and then run `!tree /workspace/models/` to print out the model repository layout as below:\n", - "\n", - "```\n", - "├── t4r_pytorch\n", - "│ ├── 1\n", - "│ └── config.pbtxt\n", - "├── t4r_pytorch_nvt\n", - "│ ├── 1\n", - "│ │ ├── model.py\n", - "│ │ ├── __pycache__\n", - "│ │ │ └── model.cpython-38.pyc\n", - "│ │ └── workflow\n", - "│ │ ├── categories\n", - "│ │ │ ├── cat_stats.category_id.parquet\n", - "│ │ │ ├── unique.brand.parquet\n", - "│ │ │ ├── unique.category_code.parquet\n", - "│ │ │ ├── unique.category_id.parquet\n", - "│ │ │ ├── unique.event_type.parquet\n", - "│ │ │ ├── unique.product_id.parquet\n", - "│ │ │ ├── unique.user_id.parquet\n", - "│ │ │ └── unique.user_session.parquet\n", - "│ │ ├── metadata.json\n", - "│ │ └── workflow.pkl\n", - "│ └── config.pbtxt\n", - "└── t4r_pytorch_pt\n", - " ├── 1\n", - " │ ├── model_info.json\n", - " │ ├── model.pkl\n", - " │ ├── model.pth\n", - " │ ├── model.py\n", - " │ └── __pycache__\n", - " │ └── model.cpython-38.pyc\n", - " └── config.pbtxt\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "79b1036b", - "metadata": {}, - "source": [ - "Triton needs a [config file](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md) to understand how to interpret the model. Let's look at the generated config file. It defines the input columns with datatype and dimensions and the output layer. Manually creating this config file can be complicated and NVTabular generates it with the `export_pytorch_ensemble()` function, which we used in the previous notebook.\n", - "\n", - "The [config file](https://github.com/triton-inference-server/server/blob/main/docs/model_configuration.md) needs the following information:\n", - "* `name`: The name of our model. Must be the same name as the parent folder.\n", - "* `platform`: The type of framework serving the model.\n", - "* `input`: The input our model expects.\n", - " * `name`: Should correspond with the model input name.\n", - " * `data_type`: Should correspond to the input's data type.\n", - " * `dims`: The dimensions of the *request* for the input. For models that support input and output tensors with variable-size dimensions, those dimensions can be listed as -1 in the input and output configuration.\n", - "* `output`: The output parameters of our model.\n", - " * `name`: Should correspond with the model output name.\n", - " * `data_type`: Should correspond to the output's data type.\n", - " * `dims`: The dimensions of the output." - ] - }, - { - "cell_type": "markdown", - "id": "0adbbbf2", - "metadata": {}, - "source": [ - "## 1.3. Loading Model" - ] - }, - { - "cell_type": "markdown", - "id": "276a0e31", - "metadata": {}, - "source": [ - "Next, let's build a client to connect to our server. The `InferenceServerClient` object is what we'll be using to talk to Triton." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "b1e8ac0f", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "client created.\n", - "GET /v2/health/live, headers None\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/tritonhttpclient/__init__.py:31: DeprecationWarning: The package `tritonhttpclient` is deprecated and will be removed in a future version. Please use instead `tritonclient.http`\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import tritonhttpclient\n", - "\n", - "try:\n", - " triton_client = tritonhttpclient.InferenceServerClient(url=\"localhost:8000\", verbose=True)\n", - " print(\"client created.\")\n", - "except Exception as e:\n", - " print(\"channel creation failed: \" + str(e))\n", - "triton_client.is_server_live()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "f61231d8", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "POST /v2/repository/index, headers None\n", - "\n", - "\n", - "bytearray(b'[{\"name\":\"t4r_pytorch\",\"version\":\"1\",\"state\":\"READY\"},{\"name\":\"t4r_pytorch_nvt\",\"version\":\"1\",\"state\":\"READY\"},{\"name\":\"t4r_pytorch_pt\",\"version\":\"1\",\"state\":\"READY\"}]')\n" - ] - }, - { - "data": { - "text/plain": [ - "[{'name': 't4r_pytorch', 'version': '1', 'state': 'READY'},\n", - " {'name': 't4r_pytorch_nvt', 'version': '1', 'state': 'READY'},\n", - " {'name': 't4r_pytorch_pt', 'version': '1', 'state': 'READY'}]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "triton_client.get_model_repository_index()" - ] - }, - { - "cell_type": "markdown", - "id": "3d091905", - "metadata": {}, - "source": [ - "We load the ensemble model" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "260d063d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "model_name = \"t4r_pytorch\"\n", - "#triton_client.load_model(model_name=model_name)" - ] - }, - { - "cell_type": "markdown", - "id": "26345f7d", - "metadata": {}, - "source": [ - "If all models are loaded successfully, you should be seeing successfully loaded status next to each model name on your terminal." - ] - }, - { - "cell_type": "markdown", - "id": "fe1debc7", - "metadata": {}, - "source": [ - "## 2. Sent Requests for Predictions" - ] - }, - { - "cell_type": "markdown", - "id": "2b2cc71a", - "metadata": {}, - "source": [ - "Load raw data for inference: We select the first 50 interactions and filter out sessions with less than 2 interactions. For this tutorial, just as an example we use the `Oct-2019` dataset that we used for model training." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "5309a22e", - "metadata": {}, - "outputs": [], - "source": [ - "INPUT_DATA_DIR = os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data/\")\n", - "df= cudf.read_parquet(os.path.join(INPUT_DATA_DIR, 'Oct-2019.parquet'))\n", - "df=df.sort_values('event_time_ts')\n", - "batch = df.iloc[:50,:]" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "592aad96", - "metadata": {}, - "outputs": [], - "source": [ - "sessions_to_use = batch.user_session.value_counts()\n", - "filtered_batch = batch[batch.user_session.isin(sessions_to_use[sessions_to_use.values>1].index.values)]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "b860b31c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
user_sessionevent_typeproduct_idcategory_idcategory_codebrandpriceuser_idevent_time_tsprod_first_event_time_ts
35629141637332view13070672053013558920217191computers.notebooklenovo251.7455005085415698880011569888001
51733284202155view10042372053013555631882655electronics.smartphoneapple1081.9853587121715698880041569888004
37412611808164view14806132053013561092866779computers.desktoppulser908.6251274288015698880051569888005
49969373794756view315000532053013558031024687<NA>luminarc41.1655097883515698880081569888008
55892595470852view287190742053013565480109009apparel.shoes.kedsbaden102.7152057193215698880101569888010
\n", - "
" - ], - "text/plain": [ - " user_session event_type product_id category_id \\\n", - "3562914 1637332 view 1307067 2053013558920217191 \n", - "5173328 4202155 view 1004237 2053013555631882655 \n", - "3741261 1808164 view 1480613 2053013561092866779 \n", - "4996937 3794756 view 31500053 2053013558031024687 \n", - "5589259 5470852 view 28719074 2053013565480109009 \n", - "\n", - " category_code brand price user_id event_time_ts \\\n", - "3562914 computers.notebook lenovo 251.74 550050854 1569888001 \n", - "5173328 electronics.smartphone apple 1081.98 535871217 1569888004 \n", - "3741261 computers.desktop pulser 908.62 512742880 1569888005 \n", - "4996937 luminarc 41.16 550978835 1569888008 \n", - "5589259 apparel.shoes.keds baden 102.71 520571932 1569888010 \n", - "\n", - " prod_first_event_time_ts \n", - "3562914 1569888001 \n", - "5173328 1569888004 \n", - "3741261 1569888005 \n", - "4996937 1569888008 \n", - "5589259 1569888010 " - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "filtered_batch.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "fbdb4f72", - "metadata": {}, - "outputs": [], - "source": [ - "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "b40c3922", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "output :\n", - " [[-12.723562 -12.491022 -10.335574 ... -12.581782 -13.05106\n", - " -12.618895 ]\n", - " [-23.642227 -22.039886 -6.7999535 ... -19.80544 -24.779701\n", - " -22.18703 ]\n", - " [-19.41877 -20.361322 -8.806894 ... -18.347404 -22.84088\n", - " -18.40315 ]\n", - " [-28.79064 -28.92058 -4.05371 ... -27.96975 -33.81291\n", - " -28.3871 ]\n", - " [-25.966614 -25.668694 -3.7074547 ... -23.999676 -29.794075\n", - " -25.53212 ]\n", - " [-19.15293 -18.776417 -7.6022983 ... -18.301687 -20.212255\n", - " -18.365705 ]]\n" - ] - } - ], - "source": [ - "import nvtabular.inference.triton as nvt_triton\n", - "import tritonclient.grpc as grpcclient\n", - "\n", - "inputs = nvt_triton.convert_df_to_triton_input(filtered_batch.columns, filtered_batch, grpcclient.InferInput)\n", - "\n", - "output_names = [\"output\"]\n", - "\n", - "outputs = []\n", - "for col in output_names:\n", - " outputs.append(grpcclient.InferRequestedOutput(col))\n", - " \n", - "MODEL_NAME_NVT = \"t4r_pytorch\"\n", - "\n", - "with grpcclient.InferenceServerClient(\"localhost:8001\") as client:\n", - " response = client.infer(MODEL_NAME_NVT, inputs)\n", - " print(col, ':\\n', response.as_numpy(col))" - ] - }, - { - "cell_type": "markdown", - "id": "2f2e07dc", - "metadata": {}, - "source": [ - "#### Visualise top-k predictions" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "45c64075", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "- Top-5 predictions for session `1167651`: 389 || 344 || 307 || 429 || 976\n", - "\n", - "- Top-5 predictions for session `1637332`: 380 || 225 || 516 || 502 || 381\n", - "\n", - "- Top-5 predictions for session `1808164`: 441 || 99 || 34 || 112 || 104\n", - "\n", - "- Top-5 predictions for session `3794756`: 398 || 239 || 2 || 361 || 5\n", - "\n", - "- Top-5 predictions for session `4202155`: 5 || 43 || 26 || 29 || 10\n", - "\n", - "- Top-5 predictions for session `5470852`: 398 || 288 || 146 || 54 || 75\n", - "\n" - ] - } - ], - "source": [ - "from transformers4rec.torch.utils.examples_utils import visualize_response\n", - "visualize_response(filtered_batch, response, top_k=5, session_col='user_session')" - ] - }, - { - "cell_type": "markdown", - "id": "1619777f", - "metadata": {}, - "source": [ - "As you see we first got prediction results (logits) from the trained model head, and then by using a handy util function `visualize_response` we extracted top-k encoded item-ids from logits. Basically, we generated recommended items for a given session.\n", - "\n", - "This is the end of the tutorial. You successfully ...\n", - "1. performed feature engineering with NVTabular\n", - "2. trained transformer architecture based session-based recommendation models with Transformers4Rec \n", - "3. deployed a trained model to Triton Inference Server, sent request and got responses from the server." - ] - }, - { - "cell_type": "markdown", - "id": "6224a7fe", - "metadata": {}, - "source": [ - "### Unload models and shut down the kernel" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f481d47f", - "metadata": {}, - "outputs": [], - "source": [ - "triton_client.unload_model(model_name=\"t4r_pytorch\")\n", - "triton_client.unload_model(model_name=\"t4r_pytorch_nvt\")\n", - "triton_client.unload_model(model_name=\"t4r_pytorch_pt\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0ae4dee2", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import IPython\n", - "app = IPython.Application.instance()\n", - "app.kernel.do_shutdown(True)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}