diff --git a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb index 8c802b4686..120b91a771 100644 --- a/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb +++ b/examples/getting-started-session-based/01-ETL-with-NVTabular.ipynb @@ -64,8 +64,6 @@ "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] @@ -187,48 +185,48 @@ " \n", " \n", " 0\n", - " 88469\n", - " 7\n", - " 2\n", - " 0.876678\n", - " 0.330926\n", + " 75772\n", " 3\n", + " 1\n", + " 0.893401\n", + " 0.830613\n", + " 7\n", " \n", " \n", " 1\n", - " 83904\n", - " 120\n", + " 82179\n", + " 77\n", " 21\n", - " 0.258138\n", - " 0.422327\n", - " 6\n", + " 0.892670\n", + " 0.745608\n", + " 5\n", " \n", " \n", " 2\n", - " 81493\n", - " 42\n", - " 8\n", - " 0.234976\n", - " 0.686702\n", - " 6\n", + " 83356\n", + " 19\n", + " 5\n", + " 0.189608\n", + " 0.011347\n", + " 5\n", " \n", " \n", " 3\n", - " 73668\n", - " 182\n", - " 32\n", - " 0.498537\n", - " 0.121023\n", - " 8\n", + " 88757\n", + " 177\n", + " 48\n", + " 0.059060\n", + " 0.771164\n", + " 7\n", " \n", " \n", " 4\n", - " 74926\n", - " 59\n", - " 11\n", - " 0.126336\n", - " 0.765945\n", - " 1\n", + " 82165\n", + " 20\n", + " 6\n", + " 0.910964\n", + " 0.449554\n", + " 3\n", " \n", " \n", "\n", @@ -236,11 +234,11 @@ ], "text/plain": [ " session_id item_id category age_days weekday_sin day\n", - "0 88469 7 2 0.876678 0.330926 3\n", - "1 83904 120 21 0.258138 0.422327 6\n", - "2 81493 42 8 0.234976 0.686702 6\n", - "3 73668 182 32 0.498537 0.121023 8\n", - "4 74926 59 11 0.126336 0.765945 1" + "0 75772 3 1 0.893401 0.830613 7\n", + "1 82179 77 21 0.892670 0.745608 5\n", + "2 83356 19 5 0.189608 0.011347 5\n", + "3 88757 177 48 0.059060 0.771164 7\n", + "4 82165 20 6 0.910964 0.449554 3" ] }, "execution_count": 6, @@ -319,17 +317,17 @@ "# Select and truncate the sequential features\n", "sequence_features_truncated = (\n", " groupby_features['category-list']\n", - " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) \n", + " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True) \n", ")\n", "\n", "sequence_features_truncated_item = (\n", " groupby_features['item_id-list']\n", - " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) \n", + " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True) \n", " >> TagAsItemID()\n", ") \n", "sequence_features_truncated_cont = (\n", " groupby_features['age_days-list', 'weekday_sin-list'] \n", - " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH) \n", + " >> nvt.ops.ListSlice(-SESSIONS_MAX_LENGTH, pad=True) \n", " >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])\n", ")\n", "\n", @@ -398,32 +396,32 @@ " \n", " 0\n", " 1\n", - " 1\n", - " 19\n", - " [27, 26, 7, 46, 13, 2, 4, 237, 10, 35, 46, 35,...\n", - " [5, 5, 2, 8, 2, 1, 2, 45, 3, 6, 8, 6, 1, 16, 3...\n", - " [0.97853184, 0.4591664, 0.083990775, 0.7000025...\n", - " [0.04896013, 0.18139902, 0.5046173, 0.48253214...\n", + " 7\n", + " 16\n", + " [10, 33, 35, 68, 19, 4, 6, 4, 37, 104, 19, 30,...\n", + " [5, 9, 11, 21, 4, 2, 2, 2, 11, 29, 4, 9, 7, 14...\n", + " [0.2510539, 0.12130147, 0.61642516, 0.45710337...\n", + " [0.66570914, 0.6149484, 0.98552155, 0.10168565...\n", " \n", " \n", " 1\n", " 2\n", - " 4\n", - " 18\n", - " [45, 7, 17, 17, 44, 17, 35, 23, 1, 194, 13, 18...\n", - " [8, 2, 4, 4, 8, 4, 6, 4, 2, 36, 2, 4, 2, 7, 7,...\n", - " [0.4287106, 0.75681955, 0.70978284, 0.8613602,...\n", - " [0.88505965, 0.07815777, 0.8684893, 0.25832322...\n", + " 9\n", + " 14\n", + " [50, 99, 18, 39, 47, 43, 5, 75, 3, 8, 163, 30,...\n", + " [14, 26, 6, 10, 14, 12, 1, 19, 1, 4, 45, 9, 2,...\n", + " [0.17742382, 0.81522274, 0.75508606, 0.1395472...\n", + " [0.29890507, 0.6564371, 0.96094626, 0.6960773,...\n", " \n", " \n", " 2\n", " 3\n", " 5\n", - " 15\n", - " [27, 38, 10, 64, 84, 109, 152, 37, 13, 158, 3,...\n", - " [5, 7, 3, 11, 15, 19, 29, 7, 2, 27, 1, 4, 5, 2...\n", - " [0.07284722, 0.48706728, 0.09015047, 0.4159043...\n", - " [0.4339925, 0.51614755, 0.3565242, 0.5322814, ...\n", + " 14\n", + " [26, 15, 51, 12, 40, 9, 54, 57, 8, 376, 57, 24...\n", + " [7, 3, 13, 4, 10, 2, 15, 15, 4, 136, 15, 8, 6,...\n", + " [0.21535604, 0.76454645, 0.82518786, 0.0410606...\n", + " [0.037031054, 0.3980902, 0.95815617, 0.7962937...\n", " \n", " \n", "\n", @@ -431,29 +429,29 @@ ], "text/plain": [ " session_id day-first item_id-count \\\n", - "0 1 1 19 \n", - "1 2 4 18 \n", - "2 3 5 15 \n", + "0 1 7 16 \n", + "1 2 9 14 \n", + "2 3 5 14 \n", "\n", " item_id-list \\\n", - "0 [27, 26, 7, 46, 13, 2, 4, 237, 10, 35, 46, 35,... \n", - "1 [45, 7, 17, 17, 44, 17, 35, 23, 1, 194, 13, 18... \n", - "2 [27, 38, 10, 64, 84, 109, 152, 37, 13, 158, 3,... \n", + "0 [10, 33, 35, 68, 19, 4, 6, 4, 37, 104, 19, 30,... \n", + "1 [50, 99, 18, 39, 47, 43, 5, 75, 3, 8, 163, 30,... \n", + "2 [26, 15, 51, 12, 40, 9, 54, 57, 8, 376, 57, 24... \n", "\n", " category-list \\\n", - "0 [5, 5, 2, 8, 2, 1, 2, 45, 3, 6, 8, 6, 1, 16, 3... \n", - "1 [8, 2, 4, 4, 8, 4, 6, 4, 2, 36, 2, 4, 2, 7, 7,... \n", - "2 [5, 7, 3, 11, 15, 19, 29, 7, 2, 27, 1, 4, 5, 2... \n", + "0 [5, 9, 11, 21, 4, 2, 2, 2, 11, 29, 4, 9, 7, 14... \n", + "1 [14, 26, 6, 10, 14, 12, 1, 19, 1, 4, 45, 9, 2,... \n", + "2 [7, 3, 13, 4, 10, 2, 15, 15, 4, 136, 15, 8, 6,... \n", "\n", " age_days-list \\\n", - "0 [0.97853184, 0.4591664, 0.083990775, 0.7000025... \n", - "1 [0.4287106, 0.75681955, 0.70978284, 0.8613602,... \n", - "2 [0.07284722, 0.48706728, 0.09015047, 0.4159043... \n", + "0 [0.2510539, 0.12130147, 0.61642516, 0.45710337... \n", + "1 [0.17742382, 0.81522274, 0.75508606, 0.1395472... \n", + "2 [0.21535604, 0.76454645, 0.82518786, 0.0410606... \n", "\n", " weekday_sin-list \n", - "0 [0.04896013, 0.18139902, 0.5046173, 0.48253214... \n", - "1 [0.88505965, 0.07815777, 0.8684893, 0.25832322... \n", - "2 [0.4339925, 0.51614755, 0.3565242, 0.5322814, ... " + "0 [0.66570914, 0.6149484, 0.98552155, 0.10168565... \n", + "1 [0.29890507, 0.6564371, 0.96094626, 0.6960773,... \n", + "2 [0.037031054, 0.3980902, 0.95815617, 0.7962937... " ] }, "execution_count": 8, @@ -533,9 +531,9 @@ " 0.0\n", " .//categories/unique.session_id.parquet\n", " 0.0\n", - " 19871.0\n", + " 19855.0\n", " session_id\n", - " 19872.0\n", + " 19856.0\n", " 408.0\n", " NaN\n", " NaN\n", @@ -573,9 +571,9 @@ " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 502.0\n", + " 495.0\n", " item_id\n", - " 503.0\n", + " 496.0\n", " 52.0\n", " NaN\n", " NaN\n", @@ -583,42 +581,42 @@ " \n", " 3\n", " item_id-list\n", - " (Tags.CATEGORICAL, Tags.ITEM_ID, Tags.ID, Tags...\n", + " (Tags.ITEM, Tags.ID, Tags.ITEM_ID, Tags.LIST, ...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", - " True\n", + " False\n", " NaN\n", " 0.0\n", " 0.0\n", " 0.0\n", " .//categories/unique.item_id.parquet\n", " 0.0\n", - " 502.0\n", + " 495.0\n", " item_id\n", - " 503.0\n", + " 496.0\n", " 52.0\n", - " 2.0\n", - " 19.0\n", + " 20.0\n", + " 20.0\n", " \n", " \n", " 4\n", " category-list\n", - " (Tags.LIST, Tags.CATEGORICAL)\n", + " (Tags.CATEGORICAL, Tags.LIST)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", - " True\n", + " False\n", " NaN\n", " 0.0\n", " 0.0\n", " 0.0\n", " .//categories/unique.category.parquet\n", " 0.0\n", - " 125.0\n", + " 178.0\n", " category\n", - " 126.0\n", - " 24.0\n", - " 2.0\n", - " 19.0\n", + " 179.0\n", + " 29.0\n", + " 20.0\n", + " 20.0\n", " \n", " \n", " 5\n", @@ -626,7 +624,7 @@ " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", - " True\n", + " False\n", " NaN\n", " NaN\n", " NaN\n", @@ -637,8 +635,8 @@ " NaN\n", " NaN\n", " NaN\n", - " 2.0\n", - " 19.0\n", + " 20.0\n", + " 20.0\n", " \n", " \n", " 6\n", @@ -646,7 +644,7 @@ " (Tags.LIST, Tags.CONTINUOUS)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", - " True\n", + " False\n", " NaN\n", " NaN\n", " NaN\n", @@ -657,15 +655,15 @@ " NaN\n", " NaN\n", " NaN\n", - " 2.0\n", - " 19.0\n", + " 20.0\n", + " 20.0\n", " \n", " \n", "\n", "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.session_id.parquet', 'domain': {'min': 0, 'max': 19871, 'name': 'session_id'}, 'embedding_sizes': {'cardinality': 19872, 'dimension': 408}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 502, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 503, 'dimension': 52}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 502, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 503, 'dimension': 52}, 'value_count': {'min': 2, 'max': 19}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 125, 'name': 'category'}, 'embedding_sizes': {'cardinality': 126, 'dimension': 24}, 'value_count': {'min': 2, 'max': 19}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True), 'is_list': True, 'is_ragged': True}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 19}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True), 'is_list': True, 'is_ragged': True}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 2, 'max': 19}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True), 'is_list': True, 'is_ragged': True}]" + "[{'name': 'session_id', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.session_id.parquet', 'domain': {'min': 0, 'max': 19855, 'name': 'session_id'}, 'embedding_sizes': {'cardinality': 19856, 'dimension': 408}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'day-first', 'tags': set(), 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 495, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 496, 'dimension': 52}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 495, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 496, 'dimension': 52}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'start_index': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 178, 'name': 'category'}, 'embedding_sizes': {'cardinality': 179, 'dimension': 29}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'age_days-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, "execution_count": 9, @@ -750,7 +748,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 27.95it/s]\n" + "Creating time-based splits: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 19.93it/s]\n" ] } ], @@ -819,48 +817,48 @@ " \n", " \n", " 0\n", - " 1\n", - " 19\n", - " [27, 26, 7, 46, 13, 2, 4, 237, 10, 35, 46, 35,...\n", - " [5, 5, 2, 8, 2, 1, 2, 45, 3, 6, 8, 6, 1, 16, 3...\n", - " [0.97853184, 0.4591664, 0.083990775, 0.7000025...\n", - " [0.04896013, 0.18139902, 0.5046173, 0.48253214...\n", + " 6\n", + " 14\n", + " [7, 11, 73, 6, 31, 5, 19, 63, 52, 1, 28, 19, 2...\n", + " [2, 5, 19, 2, 9, 1, 4, 17, 13, 1, 7, 4, 8, 5, ...\n", + " [0.84568787, 0.038363576, 0.7171949, 0.0886422...\n", + " [0.9072822, 0.55461484, 0.2662152, 0.6641106, ...\n", " \n", " \n", " 1\n", - " 17\n", - " 13\n", - " [15, 5, 5, 58, 8, 18, 29, 34, 2, 3, 43, 54, 9]\n", - " [3, 1, 1, 10, 1, 4, 6, 6, 1, 1, 8, 10, 1]\n", - " [0.76496226, 0.85960853, 0.13536207, 0.3988903...\n", - " [0.3081522, 0.17396946, 0.8448347, 0.8297997, ...\n", + " 9\n", + " 14\n", + " [42, 22, 30, 26, 19, 9, 53, 5, 51, 5, 19, 3, 2...\n", + " [12, 6, 9, 7, 4, 2, 15, 1, 13, 1, 4, 1, 8, 3, ...\n", + " [0.4074032, 0.7792388, 0.49303588, 0.027537243...\n", + " [0.65899414, 0.42423004, 0.20023833, 0.6077999...\n", " \n", " \n", " 2\n", - " 34\n", + " 14\n", " 13\n", - " [17, 12, 9, 21, 29, 6, 23, 6, 5, 176, 12, 26, 1]\n", - " [4, 3, 1, 4, 6, 1, 4, 1, 1, 29, 3, 5, 2]\n", - " [0.42529476, 0.66954064, 0.46188155, 0.2200255...\n", - " [0.951742, 0.7311401, 0.6795269, 0.5283087, 0....\n", + " [7, 60, 2, 7, 28, 2, 25, 24, 151, 74, 112, 31,...\n", + " [2, 16, 1, 2, 7, 1, 8, 8, 40, 24, 29, 9, 17, 0...\n", + " [0.9137222, 0.77429664, 0.4397028, 0.41606435,...\n", + " [0.3428851, 0.9583178, 0.07852303, 0.8921527, ...\n", " \n", " \n", " 4\n", - " 58\n", + " 39\n", " 12\n", - " [84, 11, 7, 66, 23, 1, 36, 5, 19, 22, 6, 22]\n", - " [15, 3, 2, 12, 4, 2, 7, 1, 4, 4, 1, 4]\n", - " [0.7655469, 0.4924979, 0.9192873, 0.6521773, 0...\n", - " [0.060284566, 0.9057582, 0.9853312, 0.27452144...\n", + " [67, 1, 16, 31, 21, 9, 14, 3, 8, 22, 23, 50, 0...\n", + " [17, 1, 3, 9, 6, 2, 3, 1, 4, 6, 8, 14, 0, 0, 0...\n", + " [0.7679332, 0.7644972, 0.8533882, 0.67827713, ...\n", + " [0.87136024, 0.92441916, 0.27371496, 0.4557360...\n", " \n", " \n", " 5\n", - " 64\n", + " 52\n", " 12\n", - " [12, 7, 6, 5, 26, 20, 90, 28, 132, 36, 21, 8]\n", - " [3, 2, 1, 1, 5, 2, 16, 5, 23, 7, 4, 1]\n", - " [0.86268437, 0.11732827, 0.31621945, 0.0408642...\n", - " [0.8027563, 0.7638514, 0.055432655, 0.06549974...\n", + " [31, 17, 49, 13, 49, 16, 23, 85, 23, 164, 28, ...\n", + " [9, 3, 13, 5, 13, 3, 8, 23, 8, 51, 7, 2, 0, 0,...\n", + " [0.32460424, 0.9527502, 0.77985513, 0.91916, 0...\n", + " [0.12728073, 0.87657094, 0.7073715, 0.9970732,...\n", " \n", " \n", " ...\n", @@ -872,122 +870,122 @@ " ...\n", " \n", " \n", - " 2110\n", - " 19074\n", + " 2145\n", + " 19158\n", " 2\n", - " [10, 16]\n", - " [3, 3]\n", - " [0.9952336, 0.018463716]\n", - " [0.3855745, 0.8623388]\n", + " [34, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [0.44386843, 0.17579898, 0.0, 0.0, 0.0, 0.0, 0...\n", + " [0.58763367, 0.997146, 0.0, 0.0, 0.0, 0.0, 0.0...\n", " \n", " \n", - " 2111\n", - " 19122\n", + " 2146\n", + " 19165\n", " 2\n", - " [37, 28]\n", - " [7, 5]\n", - " [0.26565734, 0.3376144]\n", - " [0.519952, 0.117240556]\n", + " [1, 60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [1, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.45839304, 0.15023704, 0.0, 0.0, 0.0, 0.0, 0...\n", + " [0.47192892, 0.6211317, 0.0, 0.0, 0.0, 0.0, 0....\n", " \n", " \n", - " 2112\n", - " 19128\n", + " 2148\n", + " 19183\n", " 2\n", - " [18, 15]\n", - " [4, 3]\n", - " [0.65739745, 0.46439078]\n", - " [0.49096248, 0.5064814]\n", + " [23, 29, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...\n", + " [8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [0.7376038, 0.7187783, 0.0, 0.0, 0.0, 0.0, 0.0...\n", + " [0.4954509, 0.5675057, 0.0, 0.0, 0.0, 0.0, 0.0...\n", " \n", " \n", - " 2113\n", - " 19134\n", + " 2149\n", + " 19199\n", " 2\n", - " [9, 116]\n", - " [1, 20]\n", - " [0.45008472, 0.36275008]\n", - " [0.10166882, 0.8127918]\n", + " [52, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [13, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...\n", + " [0.96259063, 0.8100127, 0.0, 0.0, 0.0, 0.0, 0....\n", + " [0.3484375, 0.10194607, 0.0, 0.0, 0.0, 0.0, 0....\n", " \n", " \n", - " 2114\n", - " 19136\n", + " 2150\n", + " 19221\n", " 2\n", - " [6, 9]\n", - " [1, 1]\n", - " [0.61278456, 0.64234763]\n", - " [0.46083343, 0.8074532]\n", + " [3, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...\n", + " [0.9268296, 0.71968925, 0.0, 0.0, 0.0, 0.0, 0....\n", + " [0.8299869, 0.7187812, 0.0, 0.0, 0.0, 0.0, 0.0...\n", " \n", " \n", "\n", - "

1686 rows × 6 columns

\n", + "

1718 rows × 6 columns

\n", "" ], "text/plain": [ " session_id item_id-count \\\n", - "0 1 19 \n", - "1 17 13 \n", - "2 34 13 \n", - "4 58 12 \n", - "5 64 12 \n", + "0 6 14 \n", + "1 9 14 \n", + "2 14 13 \n", + "4 39 12 \n", + "5 52 12 \n", "... ... ... \n", - "2110 19074 2 \n", - "2111 19122 2 \n", - "2112 19128 2 \n", - "2113 19134 2 \n", - "2114 19136 2 \n", + "2145 19158 2 \n", + "2146 19165 2 \n", + "2148 19183 2 \n", + "2149 19199 2 \n", + "2150 19221 2 \n", "\n", " item_id-list \\\n", - "0 [27, 26, 7, 46, 13, 2, 4, 237, 10, 35, 46, 35,... \n", - "1 [15, 5, 5, 58, 8, 18, 29, 34, 2, 3, 43, 54, 9] \n", - "2 [17, 12, 9, 21, 29, 6, 23, 6, 5, 176, 12, 26, 1] \n", - "4 [84, 11, 7, 66, 23, 1, 36, 5, 19, 22, 6, 22] \n", - "5 [12, 7, 6, 5, 26, 20, 90, 28, 132, 36, 21, 8] \n", + "0 [7, 11, 73, 6, 31, 5, 19, 63, 52, 1, 28, 19, 2... \n", + "1 [42, 22, 30, 26, 19, 9, 53, 5, 51, 5, 19, 3, 2... \n", + "2 [7, 60, 2, 7, 28, 2, 25, 24, 151, 74, 112, 31,... \n", + "4 [67, 1, 16, 31, 21, 9, 14, 3, 8, 22, 23, 50, 0... \n", + "5 [31, 17, 49, 13, 49, 16, 23, 85, 23, 164, 28, ... \n", "... ... \n", - "2110 [10, 16] \n", - "2111 [37, 28] \n", - "2112 [18, 15] \n", - "2113 [9, 116] \n", - "2114 [6, 9] \n", + "2145 [34, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2146 [1, 60, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2148 [23, 29, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... \n", + "2149 [52, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2150 [3, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "\n", " category-list \\\n", - "0 [5, 5, 2, 8, 2, 1, 2, 45, 3, 6, 8, 6, 1, 16, 3... \n", - "1 [3, 1, 1, 10, 1, 4, 6, 6, 1, 1, 8, 10, 1] \n", - "2 [4, 3, 1, 4, 6, 1, 4, 1, 1, 29, 3, 5, 2] \n", - "4 [15, 3, 2, 12, 4, 2, 7, 1, 4, 4, 1, 4] \n", - "5 [3, 2, 1, 1, 5, 2, 16, 5, 23, 7, 4, 1] \n", + "0 [2, 5, 19, 2, 9, 1, 4, 17, 13, 1, 7, 4, 8, 5, ... \n", + "1 [12, 6, 9, 7, 4, 2, 15, 1, 13, 1, 4, 1, 8, 3, ... \n", + "2 [2, 16, 1, 2, 7, 1, 8, 8, 40, 24, 29, 9, 17, 0... \n", + "4 [17, 1, 3, 9, 6, 2, 3, 1, 4, 6, 8, 14, 0, 0, 0... \n", + "5 [9, 3, 13, 5, 13, 3, 8, 23, 8, 51, 7, 2, 0, 0,... \n", "... ... \n", - "2110 [3, 3] \n", - "2111 [7, 5] \n", - "2112 [4, 3] \n", - "2113 [1, 20] \n", - "2114 [1, 1] \n", + "2145 [9, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "2146 [1, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2148 [8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", + "2149 [13, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,... \n", + "2150 [1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... \n", "\n", " age_days-list \\\n", - "0 [0.97853184, 0.4591664, 0.083990775, 0.7000025... \n", - "1 [0.76496226, 0.85960853, 0.13536207, 0.3988903... \n", - "2 [0.42529476, 0.66954064, 0.46188155, 0.2200255... \n", - "4 [0.7655469, 0.4924979, 0.9192873, 0.6521773, 0... \n", - "5 [0.86268437, 0.11732827, 0.31621945, 0.0408642... \n", + "0 [0.84568787, 0.038363576, 0.7171949, 0.0886422... \n", + "1 [0.4074032, 0.7792388, 0.49303588, 0.027537243... \n", + "2 [0.9137222, 0.77429664, 0.4397028, 0.41606435,... \n", + "4 [0.7679332, 0.7644972, 0.8533882, 0.67827713, ... \n", + "5 [0.32460424, 0.9527502, 0.77985513, 0.91916, 0... \n", "... ... \n", - "2110 [0.9952336, 0.018463716] \n", - "2111 [0.26565734, 0.3376144] \n", - "2112 [0.65739745, 0.46439078] \n", - "2113 [0.45008472, 0.36275008] \n", - "2114 [0.61278456, 0.64234763] \n", + "2145 [0.44386843, 0.17579898, 0.0, 0.0, 0.0, 0.0, 0... \n", + "2146 [0.45839304, 0.15023704, 0.0, 0.0, 0.0, 0.0, 0... \n", + "2148 [0.7376038, 0.7187783, 0.0, 0.0, 0.0, 0.0, 0.0... \n", + "2149 [0.96259063, 0.8100127, 0.0, 0.0, 0.0, 0.0, 0.... \n", + "2150 [0.9268296, 0.71968925, 0.0, 0.0, 0.0, 0.0, 0.... \n", "\n", " weekday_sin-list \n", - "0 [0.04896013, 0.18139902, 0.5046173, 0.48253214... \n", - "1 [0.3081522, 0.17396946, 0.8448347, 0.8297997, ... \n", - "2 [0.951742, 0.7311401, 0.6795269, 0.5283087, 0.... \n", - "4 [0.060284566, 0.9057582, 0.9853312, 0.27452144... \n", - "5 [0.8027563, 0.7638514, 0.055432655, 0.06549974... \n", + "0 [0.9072822, 0.55461484, 0.2662152, 0.6641106, ... \n", + "1 [0.65899414, 0.42423004, 0.20023833, 0.6077999... \n", + "2 [0.3428851, 0.9583178, 0.07852303, 0.8921527, ... \n", + "4 [0.87136024, 0.92441916, 0.27371496, 0.4557360... \n", + "5 [0.12728073, 0.87657094, 0.7073715, 0.9970732,... \n", "... ... \n", - "2110 [0.3855745, 0.8623388] \n", - "2111 [0.519952, 0.117240556] \n", - "2112 [0.49096248, 0.5064814] \n", - "2113 [0.10166882, 0.8127918] \n", - "2114 [0.46083343, 0.8074532] \n", + "2145 [0.58763367, 0.997146, 0.0, 0.0, 0.0, 0.0, 0.0... \n", + "2146 [0.47192892, 0.6211317, 0.0, 0.0, 0.0, 0.0, 0.... \n", + "2148 [0.4954509, 0.5675057, 0.0, 0.0, 0.0, 0.0, 0.0... \n", + "2149 [0.3484375, 0.10194607, 0.0, 0.0, 0.0, 0.0, 0.... \n", + "2150 [0.8299869, 0.7187812, 0.0, 0.0, 0.0, 0.0, 0.0... \n", "\n", - "[1686 rows x 6 columns]" + "[1718 rows x 6 columns]" ] }, "execution_count": 15, @@ -1009,7 +1007,7 @@ { "data": { "text/plain": [ - "490" + "565" ] }, "execution_count": 16, diff --git a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb index 1ac1435603..c46ab0dab2 100644 --- a/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb +++ b/examples/getting-started-session-based/03-serving-session-based-model-torch-backend.ipynb @@ -70,8 +70,56 @@ "text": [ "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n" + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (NDCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (DCGAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (AvgPrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (PrecisionAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n", + "/usr/local/lib/python3.8/dist-packages/torchmetrics/utilities/prints.py:36: UserWarning: Torchmetrics v0.9 introduced a new argument class property called `full_state_update` that has\n", + " not been set for this class (RecallAt). The property determines if `update` by\n", + " default needs access to the full metric state. If this is not the case, significant speedups can be\n", + " achieved and we recommend setting this to `False`.\n", + " We provide an checking function\n", + " `from torchmetrics.utilities import check_forward_full_state_property`\n", + " that can be used to check if the `full_state_update=True` (old and potential slower behaviour,\n", + " default for now) or if `full_state_update=False` can be used safely.\n", + " \n", + " warnings.warn(*args, **kwargs)\n" ] } ], @@ -201,8 +249,8 @@ " (categorical_module): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(503, 64, padding_idx=0)\n", - " (category-list): Embedding(126, 64, padding_idx=0)\n", + " (item_id-list): Embedding(496, 64, padding_idx=0)\n", + " (category-list): Embedding(179, 64, padding_idx=0)\n", " )\n", " )\n", " )\n", @@ -235,6 +283,7 @@ " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", " (dropout): Dropout(p=0.3, inplace=False)\n", + " (activation_function): GELUActivation()\n", " )\n", " (dropout): Dropout(p=0.3, inplace=False)\n", " )\n", @@ -248,6 +297,7 @@ " (layer_1): Linear(in_features=64, out_features=256, bias=True)\n", " (layer_2): Linear(in_features=256, out_features=64, bias=True)\n", " (dropout): Dropout(p=0.3, inplace=False)\n", + " (activation_function): GELUActivation()\n", " )\n", " (dropout): Dropout(p=0.3, inplace=False)\n", " )\n", @@ -273,15 +323,15 @@ " (embeddings): SequenceEmbeddingFeatures(\n", " (filter_features): FilterFeatures()\n", " (embedding_tables): ModuleDict(\n", - " (item_id-list): Embedding(503, 64, padding_idx=0)\n", - " (category-list): Embedding(126, 64, padding_idx=0)\n", + " (item_id-list): Embedding(496, 64, padding_idx=0)\n", + " (category-list): Embedding(179, 64, padding_idx=0)\n", " )\n", " )\n", - " (item_embedding_table): Embedding(503, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(496, 64, padding_idx=0)\n", " (masking): MaskedLanguageModeling()\n", " (pre): Block(\n", " (module): NextItemPredictionTask(\n", - " (item_embedding_table): Embedding(503, 64, padding_idx=0)\n", + " (item_embedding_table): Embedding(496, 64, padding_idx=0)\n", " (log_softmax): LogSoftmax(dim=-1)\n", " )\n", " )\n", @@ -386,13 +436,13 @@ { "data": { "text/plain": [ - "tensor([[27, 26, 7, ..., 32, 14, 0],\n", - " [15, 5, 5, ..., 0, 0, 0],\n", - " [17, 12, 9, ..., 0, 0, 0],\n", + "tensor([[ 7, 11, 73, ..., 0, 0, 0],\n", + " [ 42, 22, 30, ..., 0, 0, 0],\n", + " [ 7, 60, 2, ..., 0, 0, 0],\n", " ...,\n", - " [30, 13, 21, ..., 0, 0, 0],\n", - " [19, 14, 8, ..., 0, 0, 0],\n", - " [11, 27, 16, ..., 0, 0, 0]], device='cuda:0')" + " [ 18, 37, 18, ..., 0, 0, 0],\n", + " [ 12, 19, 33, ..., 0, 0, 0],\n", + " [ 11, 16, 102, ..., 0, 0, 0]], device='cuda:0')" ] }, "execution_count": 10, @@ -445,6 +495,19 @@ { "cell_type": "code", "execution_count": 13, + "id": "c215a81a-dec7-466b-aeb5-1e698f0b021f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for col_name, col_schema in input_schema.column_schemas.items():\n", + " input_schema[col_name] = input_schema[col_name].with_shape((None, sparse_max[col_name]))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, "id": "757cd0c5-f581-488b-a8de-b8d1188820d6", "metadata": {}, "outputs": [ @@ -476,38 +539,46 @@ " is_ragged\n", " properties.int_domain.min\n", " properties.int_domain.max\n", + " properties.value_count.min\n", + " properties.value_count.max\n", " \n", " \n", " \n", " \n", " 0\n", " age_days-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", " 0\n", " 0\n", + " 20\n", + " 20\n", " \n", " \n", " 1\n", " weekday_sin-list\n", - " (Tags.LIST, Tags.CONTINUOUS)\n", + " (Tags.CONTINUOUS, Tags.LIST)\n", " DType(name='float32', element_type=<ElementTyp...\n", " True\n", " False\n", " 0\n", " 0\n", + " 20\n", + " 20\n", " \n", " \n", " 2\n", " item_id-list\n", - " (Tags.CATEGORICAL, Tags.ITEM_ID, Tags.ITEM, Ta...\n", + " (Tags.ITEM_ID, Tags.ID, Tags.LIST, Tags.CATEGO...\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " False\n", " 0\n", - " 502\n", + " 495\n", + " 20\n", + " 20\n", " \n", " \n", " 3\n", @@ -517,17 +588,19 @@ " True\n", " False\n", " 0\n", - " 125\n", + " 178\n", + " 20\n", + " 20\n", " \n", " \n", "\n", "" ], "text/plain": [ - "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 502}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 125}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True), 'is_list': True, 'is_ragged': False}]" + "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 495}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 178}, 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -546,7 +619,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "6b2deb2b-e223-4b5d-b655-810e1aefa7e8", "metadata": {}, "outputs": [], @@ -569,7 +642,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "4f96597c-1c05-4fb0-ad3e-c55c21599158", "metadata": {}, "outputs": [], @@ -591,7 +664,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "id": "5b0d14bb-7765-45e8-8fd0-9d508dc3ec14", "metadata": {}, "outputs": [], @@ -600,6 +673,117 @@ "ens_config, node_configs = ensemble.export(ens_model_path)" ] }, + { + "cell_type": "code", + "execution_count": 18, + "id": "a3ba86eb-ca25-4a0c-9daf-61c9911b29ab", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nametagsdtypeis_listis_raggedproperties.int_domain.minproperties.int_domain.maxproperties.triton_scalar_shapeproperties.value_count.minproperties.value_count.max
0age_days-list(Tags.CONTINUOUS, Tags.LIST)DType(name='float32', element_type=<ElementTyp...TrueFalse00[]2020
1weekday_sin-list(Tags.CONTINUOUS, Tags.LIST)DType(name='float32', element_type=<ElementTyp...TrueFalse00[]2020
2item_id-list(Tags.ITEM_ID, Tags.ID, Tags.LIST, Tags.CATEGO...DType(name='int64', element_type=<ElementType....TrueFalse0495[]2020
3category-list(Tags.LIST, Tags.CATEGORICAL)DType(name='int64', element_type=<ElementType....TrueFalse0178[]2020
\n", + "
" + ], + "text/plain": [ + "[{'name': 'age_days-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}, 'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'weekday_sin-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 0}, 'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'int_domain': {'min': 0, 'max': 495}, 'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}, {'name': 'category-list', 'tags': {, }, 'properties': {'int_domain': {'min': 0, 'max': 178}, 'triton_scalar_shape': [], 'value_count': {'min': 20, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=20, max=20)))), 'is_list': True, 'is_ragged': False}]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ensemble.input_schema" + ] + }, { "cell_type": "markdown", "id": "a36169a5-f218-44b5-b034-7d299ce718ed", @@ -622,7 +806,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "id": "46a86c8d-9ec1-4422-8f8c-4d49e83f6783", "metadata": {}, "outputs": [ @@ -655,7 +839,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "id": "dda3f852-a019-4bf1-831b-f63b750a1192", "metadata": {}, "outputs": [ @@ -678,7 +862,7 @@ " {'name': 'executor_model', 'version': '1', 'state': 'READY'}]" ] }, - "execution_count": 18, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -710,7 +894,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "0acd5649-31fe-4f3f-87a2-2607477638b5", "metadata": {}, "outputs": [], @@ -724,18 +908,18 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "id": "1a4894de-939f-4c3b-8c76-6f4d6f91d787", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([ 14, 71, 45, 35, 140, 89, 7, 115, 196, 19, 2, 10, 0, 0,\n", + "tensor([ 8, 8, 10, 89, 130, 71, 16, 10, 9, 8, 8, 40, 0, 0,\n", " 0, 0, 0, 0, 0, 0], device='cuda:0')" ] }, - "execution_count": 21, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -746,7 +930,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "id": "0306fc5a-5f54-4a58-b762-97b38908b290", "metadata": {}, "outputs": [ @@ -787,38 +971,38 @@ " \n", " \n", " 0\n", - " [0.9509504, 0.3658292, 0.10605793, 0.8901615, ...\n", - " [0.9222485, 0.1284022, 0.92028487, 0.3788347, ...\n", - " [14, 71, 45, 35, 140, 89, 7, 115, 196, 19, 2, ...\n", - " [3, 14, 8, 6, 27, 16, 2, 20, 31, 4, 1, 3, 0, 0...\n", + " [0.9501408, 0.49153143, 0.43261203, 0.60583574...\n", + " [0.50088716, 0.8589678, 0.25768423, 0.6750298,...\n", + " [8, 8, 10, 89, 130, 71, 16, 10, 9, 8, 8, 40, 0...\n", + " [4, 4, 5, 27, 38, 19, 3, 5, 2, 4, 4, 10, 0, 0,...\n", " \n", " \n", " 1\n", - " [0.23776619, 0.062151734, 0.059320305, 0.37635...\n", - " [0.75332737, 0.18823138, 0.5440263, 0.27081072...\n", - " [6, 4, 42, 97, 208, 5, 50, 45, 7, 2, 0, 0, 0, ...\n", - " [1, 2, 7, 18, 34, 1, 9, 8, 2, 1, 0, 0, 0, 0, 0...\n", + " [0.8527362, 0.12124015, 0.882591, 0.45782763, ...\n", + " [0.47754368, 0.9635016, 0.2031468, 0.41193682,...\n", + " [5, 67, 25, 1, 30, 15, 10, 67, 38, 13, 19, 0, ...\n", + " [1, 17, 8, 1, 9, 3, 5, 17, 10, 5, 4, 0, 0, 0, ...\n", " \n", " \n", " 2\n", - " [0.6510976, 0.002470178, 0.19554594, 0.6035013...\n", - " [0.0155129675, 0.067784436, 0.6556247, 0.90605...\n", - " [25, 38, 126, 2, 14, 10, 8, 14, 16, 28, 0, 0, ...\n", - " [5, 7, 21, 1, 3, 3, 1, 3, 3, 5, 0, 0, 0, 0, 0,...\n", + " [0.40404844, 0.59667224, 0.011392503, 0.209335...\n", + " [0.9317083, 0.071744286, 0.958608, 0.24852555,...\n", + " [107, 13, 28, 18, 12, 91, 12, 2, 98, 1, 0, 0, ...\n", + " [33, 5, 7, 6, 4, 23, 4, 1, 26, 1, 0, 0, 0, 0, ...\n", " \n", " \n", " 3\n", - " [0.62920743, 0.7574743, 0.1393074, 0.14867006,...\n", - " [0.44066542, 0.6632927, 0.51982445, 0.8328001,...\n", - " [4, 12, 26, 19, 23, 124, 22, 2, 50, 38, 0, 0, ...\n", - " [2, 3, 5, 4, 4, 22, 4, 1, 9, 7, 0, 0, 0, 0, 0,...\n", + " [0.10275215, 0.5568824, 0.5089987, 0.14826113,...\n", + " [0.6575956, 0.71825516, 0.5113613, 0.3175862, ...\n", + " [8, 27, 10, 42, 6, 14, 12, 164, 77, 32, 0, 0, ...\n", + " [4, 7, 5, 12, 2, 3, 4, 51, 21, 4, 0, 0, 0, 0, ...\n", " \n", " \n", " 4\n", - " [0.4540216, 0.66014326, 0.4065639, 0.90007794,...\n", - " [0.5709135, 0.41235211, 0.21241243, 0.01835139...\n", - " [33, 29, 46, 15, 14, 27, 38, 115, 60, 122, 0, ...\n", - " [6, 6, 8, 3, 3, 5, 7, 20, 11, 21, 0, 0, 0, 0, ...\n", + " [0.87814647, 0.77221537, 0.20481698, 0.5081556...\n", + " [0.7751088, 0.00625885, 0.51996744, 0.73278934...\n", + " [12, 27, 2, 23, 5, 13, 26, 4, 19, 52, 0, 0, 0,...\n", + " [4, 7, 1, 8, 1, 5, 7, 2, 4, 13, 0, 0, 0, 0, 0,...\n", " \n", " \n", "\n", @@ -826,35 +1010,35 @@ ], "text/plain": [ " age_days-list \\\n", - "0 [0.9509504, 0.3658292, 0.10605793, 0.8901615, ... \n", - "1 [0.23776619, 0.062151734, 0.059320305, 0.37635... \n", - "2 [0.6510976, 0.002470178, 0.19554594, 0.6035013... \n", - "3 [0.62920743, 0.7574743, 0.1393074, 0.14867006,... \n", - "4 [0.4540216, 0.66014326, 0.4065639, 0.90007794,... \n", + "0 [0.9501408, 0.49153143, 0.43261203, 0.60583574... \n", + "1 [0.8527362, 0.12124015, 0.882591, 0.45782763, ... \n", + "2 [0.40404844, 0.59667224, 0.011392503, 0.209335... \n", + "3 [0.10275215, 0.5568824, 0.5089987, 0.14826113,... \n", + "4 [0.87814647, 0.77221537, 0.20481698, 0.5081556... \n", "\n", " weekday_sin-list \\\n", - "0 [0.9222485, 0.1284022, 0.92028487, 0.3788347, ... \n", - "1 [0.75332737, 0.18823138, 0.5440263, 0.27081072... \n", - "2 [0.0155129675, 0.067784436, 0.6556247, 0.90605... \n", - "3 [0.44066542, 0.6632927, 0.51982445, 0.8328001,... \n", - "4 [0.5709135, 0.41235211, 0.21241243, 0.01835139... \n", + "0 [0.50088716, 0.8589678, 0.25768423, 0.6750298,... \n", + "1 [0.47754368, 0.9635016, 0.2031468, 0.41193682,... \n", + "2 [0.9317083, 0.071744286, 0.958608, 0.24852555,... \n", + "3 [0.6575956, 0.71825516, 0.5113613, 0.3175862, ... \n", + "4 [0.7751088, 0.00625885, 0.51996744, 0.73278934... \n", "\n", " item_id-list \\\n", - "0 [14, 71, 45, 35, 140, 89, 7, 115, 196, 19, 2, ... \n", - "1 [6, 4, 42, 97, 208, 5, 50, 45, 7, 2, 0, 0, 0, ... \n", - "2 [25, 38, 126, 2, 14, 10, 8, 14, 16, 28, 0, 0, ... \n", - "3 [4, 12, 26, 19, 23, 124, 22, 2, 50, 38, 0, 0, ... \n", - "4 [33, 29, 46, 15, 14, 27, 38, 115, 60, 122, 0, ... \n", + "0 [8, 8, 10, 89, 130, 71, 16, 10, 9, 8, 8, 40, 0... \n", + "1 [5, 67, 25, 1, 30, 15, 10, 67, 38, 13, 19, 0, ... \n", + "2 [107, 13, 28, 18, 12, 91, 12, 2, 98, 1, 0, 0, ... \n", + "3 [8, 27, 10, 42, 6, 14, 12, 164, 77, 32, 0, 0, ... \n", + "4 [12, 27, 2, 23, 5, 13, 26, 4, 19, 52, 0, 0, 0,... \n", "\n", " category-list \n", - "0 [3, 14, 8, 6, 27, 16, 2, 20, 31, 4, 1, 3, 0, 0... \n", - "1 [1, 2, 7, 18, 34, 1, 9, 8, 2, 1, 0, 0, 0, 0, 0... \n", - "2 [5, 7, 21, 1, 3, 3, 1, 3, 3, 5, 0, 0, 0, 0, 0,... \n", - "3 [2, 3, 5, 4, 4, 22, 4, 1, 9, 7, 0, 0, 0, 0, 0,... \n", - "4 [6, 6, 8, 3, 3, 5, 7, 20, 11, 21, 0, 0, 0, 0, ... " + "0 [4, 4, 5, 27, 38, 19, 3, 5, 2, 4, 4, 10, 0, 0,... \n", + "1 [1, 17, 8, 1, 9, 3, 5, 17, 10, 5, 4, 0, 0, 0, ... \n", + "2 [33, 5, 7, 6, 4, 23, 4, 1, 26, 1, 0, 0, 0, 0, ... \n", + "3 [4, 7, 5, 12, 2, 3, 4, 51, 21, 4, 0, 0, 0, 0, ... \n", + "4 [4, 7, 1, 8, 1, 5, 7, 2, 4, 13, 0, 0, 0, 0, 0,... " ] }, - "execution_count": 22, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -882,7 +1066,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "id": "42091c25-7676-414e-bb8c-8432aeb58297", "metadata": {}, "outputs": [ @@ -890,19 +1074,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'next-item': array([[ -9.769284 , -3.3535378, -3.5593104, ..., -10.696345 ,\n", - " -9.082857 , -9.554779 ],\n", - " [ -9.769166 , -3.3535283, -3.5592926, ..., -10.696279 ,\n", - " -9.082819 , -9.55474 ],\n", - " [ -9.768643 , -3.3534937, -3.559177 , ..., -10.696127 ,\n", - " -9.0826 , -9.554597 ],\n", + "{'next-item': array([[ -9.942884 , -3.42258 , -3.4635031, ..., -9.813356 ,\n", + " -10.095871 , -9.584209 ],\n", + " [ -9.943485 , -3.4231985, -3.463529 , ..., -9.813207 ,\n", + " -10.095795 , -9.583893 ],\n", + " [ -9.94396 , -3.4236832, -3.463609 , ..., -9.813052 ,\n", + " -10.095815 , -9.583604 ],\n", " ...,\n", - " [ -9.769294 , -3.3535573, -3.559361 , ..., -10.696278 ,\n", - " -9.082909 , -9.554747 ],\n", - " [ -9.769636 , -3.3535905, -3.5594552, ..., -10.696384 ,\n", - " -9.083048 , -9.554836 ],\n", - " [ -9.769545 , -3.353582 , -3.5594208, ..., -10.696352 ,\n", - " -9.083025 , -9.554812 ]], dtype=float32)}\n" + " [ -9.942841 , -3.4225972, -3.4636211, ..., -9.813358 ,\n", + " -10.0957985, -9.584264 ],\n", + " [ -9.943155 , -3.4229283, -3.4636493, ..., -9.813265 ,\n", + " -10.095781 , -9.584093 ],\n", + " [ -9.943013 , -3.4227753, -3.4636147, ..., -9.813282 ,\n", + " -10.09578 , -9.584151 ]], dtype=float32)}\n" ] } ], @@ -914,17 +1098,17 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "id": "0fa425a4-9c00-45ed-a4b1-fd75ca4bf819", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(32, 503)" + "(32, 496)" ] }, - "execution_count": 24, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" }