Remove start_index from Categorify op (#714)
* remove start_index from Categorify

* fix docker image version
rnyak authored Jun 6, 2023
1 parent 63be2bf commit 2eb6da1
Showing 6 changed files with 229 additions and 1,625 deletions.
83 changes: 32 additions & 51 deletions examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb
@@ -39,7 +39,7 @@
"\n",
"**Launch the docker container**\n",
"```\n",
"docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -p 8888:8888 -v <path_to_data>:/workspace/data/ nvcr.io/nvidia/merlin/merlin-pytorch:22.XX\n",
"docker run -it --gpus device=0 -p 8000:8000 -p 8001:8001 -p 8002:8002 -p 8888:8888 -v <path_to_data>:/workspace/data/ nvcr.io/nvidia/merlin/merlin-pytorch:23.XX\n",
"```\n",
"This script will mount your local data folder that includes your data files to `/workspace/data` directory in the merlin-pytorch docker container."
]
@@ -210,11 +210,11 @@
"output_type": "stream",
"text": [
" session_id timestamp item_id category itemid_ts_first\n",
"0 4993 1396727816 214835285 0 1396332436\n",
"1 4993 1396727863 214530703 0 1396339114\n",
"2 4993 1396727898 214530705 0 1396330224\n",
"3 4993 1396728063 214835713 0 1396327474\n",
"4 4993 1396730097 214512611 0 1396328044\n"
"0 7401 1396439960 214826816 0 1396321828\n",
"1 7402 1396780751 214613743 0 1396329089\n",
"2 7402 1396780780 214827011 0 1396735848\n",
"3 7402 1396780912 214821388 0 1396330458\n",
"4 7402 1396780991 214827011 0 1396735848\n"
]
}
],
@@ -272,7 +272,7 @@
{
"data": {
"text/plain": [
"518"
"0"
]
},
"execution_count": 10,
@@ -321,7 +321,7 @@
"source": [
"In this cell, we are defining three transformations ops: \n",
"\n",
"- 1. Encoding categorical variables using `Categorify()` op. We set `start_index` to 1 so that encoded null values start from `1` instead of `0` because we reserve `0` for padding the sequence features.\n",
"- 1. Encoding categorical variables using `Categorify()` op. Categorify op maps nulls to `1`, OOVs to `2`, automatically. We reserve `0` for padding the sequence features. The encoding of each category starts from 3.\n",
"- 2. Deriving temporal features from timestamp and computing their cyclical representation using a custom lambda function. \n",
"- 3. Computing the item recency in days using a custom op. Note that item recency is defined as the difference between the first occurrence of the item in dataset and the actual date of item interaction. \n",
"\n",
@@ -336,7 +336,7 @@
"outputs": [],
"source": [
"# Encodes categorical features as contiguous integers\n",
"cat_feats = ColumnSelector(['category', 'item_id']) >> nvt.ops.Categorify(start_index=1)\n",
"cat_feats = ColumnSelector(['category', 'item_id']) >> nvt.ops.Categorify()\n",
"\n",
"# create time features\n",
"session_ts = ColumnSelector(['timestamp'])\n",
@@ -494,18 +494,7 @@
"execution_count": 13,
"id": "45803886",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n",
"/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [<Tags.ITEM: 'item'>, <Tags.ID: 'id'>].\n",
" warnings.warn(\n"
]
}
],
"outputs": [],
"source": [
"dataset = nvt.Dataset(interactions_merged_df)\n",
"workflow = nvt.Workflow(filtered_sessions)\n",
@@ -559,7 +548,6 @@
" <th>properties.num_buckets</th>\n",
" <th>properties.freq_threshold</th>\n",
" <th>properties.max_size</th>\n",
" <th>properties.start_index</th>\n",
" <th>properties.cat_path</th>\n",
" <th>properties.domain.min</th>\n",
" <th>properties.domain.max</th>\n",
@@ -589,7 +577,6 @@
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
@@ -601,41 +588,39 @@
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>.//categories/unique.item_id.parquet</td>\n",
" <td>0.0</td>\n",
" <td>52740.0</td>\n",
" <td>item_id</td>\n",
" <td>52741.0</td>\n",
" <td>item_id</td>\n",
" <td>52742.0</td>\n",
" <td>512.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>item_id-list</td>\n",
" <td>(Tags.CATEGORICAL, Tags.ITEM_ID, Tags.ITEM, Ta...</td>\n",
" <td>(Tags.CATEGORICAL, Tags.ITEM, Tags.ID, Tags.LIST)</td>\n",
" <td>DType(name='int64', element_type=&lt;ElementType....</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>.//categories/unique.item_id.parquet</td>\n",
" <td>0.0</td>\n",
" <td>52740.0</td>\n",
" <td>item_id</td>\n",
" <td>52741.0</td>\n",
" <td>item_id</td>\n",
" <td>52742.0</td>\n",
" <td>512.0</td>\n",
" <td>0.0</td>\n",
" <td>20.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>et_dayofweek_sin-list</td>\n",
" <td>(Tags.LIST, Tags.CONTINUOUS)</td>\n",
" <td>DType(name='float32', element_type=&lt;ElementTyp...</td>\n",
" <td>(Tags.CONTINUOUS, Tags.LIST)</td>\n",
" <td>DType(name='float64', element_type=&lt;ElementTyp...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>NaN</td>\n",
@@ -647,14 +632,13 @@
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>20.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>product_recency_days_log_norm-list</td>\n",
" <td>(Tags.LIST, Tags.CONTINUOUS)</td>\n",
" <td>(Tags.CONTINUOUS, Tags.LIST)</td>\n",
" <td>DType(name='float32', element_type=&lt;ElementTyp...</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
@@ -667,7 +651,6 @@
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>20.0</td>\n",
" </tr>\n",
@@ -681,12 +664,11 @@
" <td>NaN</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>.//categories/unique.category.parquet</td>\n",
" <td>0.0</td>\n",
" <td>335.0</td>\n",
" <td>category</td>\n",
" <td>336.0</td>\n",
" <td>category</td>\n",
" <td>337.0</td>\n",
" <td>42.0</td>\n",
" <td>0.0</td>\n",
" <td>20.0</td>\n",
@@ -709,14 +691,13 @@
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"[{'name': 'session_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52740, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52741, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM_ID: 'item_id'>, <Tags.ITEM: 'item'>, <Tags.LIST: 'list'>, <Tags.ID: 'id'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52740, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52741, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {<Tags.LIST: 'list'>, <Tags.CONTINUOUS: 'continuous'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {<Tags.LIST: 'list'>, <Tags.CONTINUOUS: 
'continuous'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 1, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 335, 'name': 'category'}, 'embedding_sizes': {'cardinality': 336, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]"
"[{'name': 'session_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>, <Tags.ID: 'id'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=<ElementType.Float: 'float'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 
20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]"
]
},
"execution_count": 14,
@@ -792,24 +773,24 @@
"6606149 11255553 2 \n",
"\n",
" item_id-list \\\n",
"6606147 [604, 878, 742, 90, 4777, 1583, 3446, 8083, 34... \n",
"6606148 [184, 12288] \n",
"6606149 [7299, 1953] \n",
"6606147 [605, 879, 743, 91, 4778, 1584, 3447, 8084, 34... \n",
"6606148 [185, 12289] \n",
"6606149 [7300, 1954] \n",
"\n",
" et_dayofweek_sin-list \\\n",
"6606147 [-0.43388462, -0.43388462, -0.43388462, -0.433... \n",
"6606148 [-0.43388462, -0.43388462] \n",
"6606149 [-0.781831, -0.781831] \n",
"6606147 [-0.43388454782514785, -0.43388454782514785, -... \n",
"6606148 [-0.43388454782514785, -0.43388454782514785] \n",
"6606149 [-0.7818309228245777, -0.7818309228245777] \n",
"\n",
" product_recency_days_log_norm-list \\\n",
"6606147 [1.5241553, 1.5238751, 1.5239341, 1.5241631, 1... \n",
"6606148 [-0.5330064, 1.521494] \n",
"6606149 [1.5338266, 1.5355074] \n",
"\n",
" category-list day_index \n",
"6606147 [3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3] 178 \n",
"6606148 [2, 2] 178 \n",
"6606149 [7, 7] 180 \n"
"6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4] 178 \n",
"6606148 [3, 3] 178 \n",
"6606149 [8, 8] 180 \n"
]
}
],
@@ -827,7 +808,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Creating time-based splits: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.76it/s]\n"
"Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.37it/s]\n"
]
}
],
@@ -849,7 +830,7 @@
{
"data": {
"text/plain": [
"570"
"583"
]
},
"execution_count": 19,
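
Note on the new index layout: the updated markdown cell in this commit says that `Categorify()` reserves `0` for padding, maps nulls to `1` and OOVs to `2`, and starts the real categories at `3`. The plain-Python sketch below only illustrates that layout. It is not NVTabular's implementation, and the helper names (`fit_vocab`, `encode`, `pad`) as well as the frequency-based ordering of categories are assumptions made for the example.

```python
from collections import Counter

# Reserved ids, as described in the updated notebook text.
PAD, NULL, OOV = 0, 1, 2
FIRST_CATEGORY = 3


def fit_vocab(values):
    """Assign ids starting at 3 to every non-null value seen during fitting.

    Ordering by descending frequency is an assumption for this sketch.
    """
    counts = Counter(v for v in values if v is not None)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {v: i + FIRST_CATEGORY for i, v in enumerate(ordered)}


def encode(values, vocab):
    """Map nulls to 1, unseen values to 2, and known values to their id."""
    return [NULL if v is None else vocab.get(v, OOV) for v in values]


def pad(seq, length):
    """Right-pad (or truncate) a sequence with the reserved padding id 0."""
    return (seq + [PAD] * length)[:length]


# Illustrative item ids; the None stands for a missing value.
train = [214835285, 214530703, 214530703, None]
vocab = fit_vocab(train)

# 999 was never seen during fitting, so it is treated as OOV.
encoded = encode([214530703, 214835285, None, 999], vocab)
padded = pad(encoded, 6)
print(encoded)
print(padded)
```

This is consistent with the shifted outputs in the diff above, where encoded ids move up by one (for example `604` becomes `605` in `item_id-list`) and the reported cardinalities grow by one (`52741` to `52742` for `item_id`).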