diff --git a/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb b/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb index 18a4affd1..6af5c85e6 100644 --- a/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb +++ b/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb @@ -205,7 +205,7 @@ " \n", " 0\n", " item_id-list\n", - " (Tags.CATEGORICAL, Tags.LIST, Tags.ID, Tags.ITEM)\n", + " (Tags.CATEGORICAL, Tags.ITEM, Tags.ID, Tags.LIST)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", @@ -283,7 +283,7 @@ "" ], "text/plain": [ - "[{'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 52742.0, 'dimension': 512.0}, 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 337.0, 'dimension': 42.0}, 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}]" + "[{'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 52742.0, 'dimension': 512.0}, 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 337.0, 'dimension': 42.0}, 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}]" ] }, "execution_count": 5, @@ -531,7 +531,7 @@ "
\n", " \n", " \n", - " [560/560 00:25, Epoch 10/10]\n", + " [560/560 00:33, Epoch 10/10]\n", "
\n", " \n", " \n", @@ -543,11 +543,11 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
2007.5856007.596700
4006.6087006.609100

" @@ -579,7 +579,7 @@ "

\n", " \n", " \n", - " [11/11 00:33]\n", + " [11/11 00:45]\n", "
\n", " " ], @@ -610,12 +610,12 @@ "\n", "***** Evaluation results for day 179:*****\n", "\n", - " eval_/next-item/avg_precision@10 = 0.07277625054121017\n", - " eval_/next-item/avg_precision@20 = 0.077287457883358\n", - " eval_/next-item/ndcg@10 = 0.1008271649479866\n", - " eval_/next-item/ndcg@20 = 0.11763089150190353\n", - " eval_/next-item/recall@10 = 0.18959537148475647\n", - " eval_/next-item/recall@20 = 0.25549131631851196\n", + " eval_/next-item/avg_precision@10 = 0.07518380880355835\n", + " eval_/next-item/avg_precision@20 = 0.07953299582004547\n", + " eval_/next-item/ndcg@10 = 0.10182357579469681\n", + " eval_/next-item/ndcg@20 = 0.11847736686468124\n", + " eval_/next-item/recall@10 = 0.1872832328081131\n", + " eval_/next-item/recall@20 = 0.2520231306552887\n", "\n", "***** Launch training for day 179: *****\n" ] @@ -627,7 +627,7 @@ "
\n", " \n", " \n", - " [400/400 00:17, Epoch 10/10]\n", + " [400/400 00:23, Epoch 10/10]\n", "
\n", " \n", " \n", @@ -639,11 +639,11 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
2006.8384006.858000
4006.3046006.317900

" @@ -680,12 +680,12 @@ "\n", "***** Evaluation results for day 180:*****\n", "\n", - " eval_/next-item/avg_precision@10 = 0.059328265488147736\n", - " eval_/next-item/avg_precision@20 = 0.06352042406797409\n", - " eval_/next-item/ndcg@10 = 0.08318208903074265\n", - " eval_/next-item/ndcg@20 = 0.09845318645238876\n", - " eval_/next-item/recall@10 = 0.16083915531635284\n", - " eval_/next-item/recall@20 = 0.2209790199995041\n", + " eval_/next-item/avg_precision@10 = 0.060561101883649826\n", + " eval_/next-item/avg_precision@20 = 0.06499475985765457\n", + " eval_/next-item/ndcg@10 = 0.08409972488880157\n", + " eval_/next-item/ndcg@20 = 0.10058131068944931\n", + " eval_/next-item/recall@10 = 0.1594405621290207\n", + " eval_/next-item/recall@20 = 0.22517482936382294\n", "\n", "***** Launch training for day 180: *****\n" ] @@ -697,7 +697,7 @@ "

\n", " \n", " \n", - " [330/330 00:14, Epoch 10/10]\n", + " [330/330 00:19, Epoch 10/10]\n", "
\n", " \n", " \n", @@ -709,7 +709,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
2006.7080006.712100

" @@ -739,12 +739,12 @@ "\n", "***** Evaluation results for day 181:*****\n", "\n", - " eval_/next-item/avg_precision@10 = 0.12736327946186066\n", - " eval_/next-item/avg_precision@20 = 0.13500627875328064\n", - " eval_/next-item/ndcg@10 = 0.16738776862621307\n", - " eval_/next-item/ndcg@20 = 0.19680777192115784\n", - " eval_/next-item/recall@10 = 0.29406309127807617\n", - " eval_/next-item/recall@20 = 0.41187384724617004\n" + " eval_/next-item/avg_precision@10 = 0.12142563611268997\n", + " eval_/next-item/avg_precision@20 = 0.128932923078537\n", + " eval_/next-item/ndcg@10 = 0.16741396486759186\n", + " eval_/next-item/ndcg@20 = 0.19582374393939972\n", + " eval_/next-item/recall@10 = 0.30705010890960693\n", + " eval_/next-item/recall@20 = 0.41929498314857483\n" ] } ], @@ -784,24 +784,24 @@ { "data": { "text/plain": [ - "{'indexed_by_time_eval_/next-item/avg_precision@10': [0.07277625054121017,\n", - " 0.059328265488147736,\n", - " 0.12736327946186066],\n", - " 'indexed_by_time_eval_/next-item/avg_precision@20': [0.077287457883358,\n", - " 0.06352042406797409,\n", - " 0.13500627875328064],\n", - " 'indexed_by_time_eval_/next-item/ndcg@10': [0.1008271649479866,\n", - " 0.08318208903074265,\n", - " 0.16738776862621307],\n", - " 'indexed_by_time_eval_/next-item/ndcg@20': [0.11763089150190353,\n", - " 0.09845318645238876,\n", - " 0.19680777192115784],\n", - " 'indexed_by_time_eval_/next-item/recall@10': [0.18959537148475647,\n", - " 0.16083915531635284,\n", - " 0.29406309127807617],\n", - " 'indexed_by_time_eval_/next-item/recall@20': [0.25549131631851196,\n", - " 0.2209790199995041,\n", - " 0.41187384724617004]}" + "{'indexed_by_time_eval_/next-item/avg_precision@10': [0.07518380880355835,\n", + " 0.060561101883649826,\n", + " 0.12142563611268997],\n", + " 'indexed_by_time_eval_/next-item/avg_precision@20': [0.07953299582004547,\n", + " 0.06499475985765457,\n", + " 0.128932923078537],\n", + " 'indexed_by_time_eval_/next-item/ndcg@10': [0.10182357579469681,\n", + " 0.08409972488880157,\n", + " 0.16741396486759186],\n", + " 'indexed_by_time_eval_/next-item/ndcg@20': [0.11847736686468124,\n", + " 0.10058131068944931,\n", + " 0.19582374393939972],\n", + " 'indexed_by_time_eval_/next-item/recall@10': [0.1872832328081131,\n", + " 0.1594405621290207,\n", + " 0.30705010890960693],\n", + " 'indexed_by_time_eval_/next-item/recall@20': [0.2520231306552887,\n", + " 0.22517482936382294,\n", + " 0.41929498314857483]}" ] }, "execution_count": 11, @@ -825,12 +825,12 @@ "name": "stdout", "output_type": "stream", "text": [ - " indexed_by_time_eval_/next-item/avg_precision@10 = 0.08648926516373952\n", - " indexed_by_time_eval_/next-item/avg_precision@20 = 0.09193805356820424\n", - " indexed_by_time_eval_/next-item/ndcg@10 = 0.1171323408683141\n", - " indexed_by_time_eval_/next-item/ndcg@20 = 0.13763061662515005\n", - " indexed_by_time_eval_/next-item/recall@10 = 0.21483253935972849\n", - " indexed_by_time_eval_/next-item/recall@20 = 0.2961147278547287\n" + " indexed_by_time_eval_/next-item/avg_precision@10 = 0.08572351559996605\n", + " indexed_by_time_eval_/next-item/avg_precision@20 = 0.09115355958541234\n", + " indexed_by_time_eval_/next-item/ndcg@10 = 0.11777908851703008\n", + " indexed_by_time_eval_/next-item/ndcg@20 = 0.13829414049784342\n", + " indexed_by_time_eval_/next-item/recall@10 = 0.21792463461558023\n", + " indexed_by_time_eval_/next-item/recall@20 = 0.2988309810558955\n" ] } ], @@ -1105,20 +1105,7 @@ "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'session_id', which is not being used by any downstream operator in the ensemble graph.\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'item_id-count', which is not being used by any downstream operator in the ensemble graph.\n", - " warnings.warn(\n", - "/usr/local/lib/python3.8/dist-packages/merlin/systems/dag/node.py:100: UserWarning: Operator 'TransformWorkflow' is producing the output column 'day_index', which is not being used by any downstream operator in the ensemble graph.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "ensemble = Ensemble(torch_op, workflow.input_schema)\n", "ens_config, node_configs = ensemble.export(ens_model_path)" @@ -1164,7 +1151,7 @@ "\n", "`tritonserver --model-repository=`\n", "\n", - "For the `--model-repository` argument, specify the same path as the export_path that you specified previously in the `ensemble.export` method. This command will launch the server and load all the models to the server. Once all the models are loaded successfully, you should see READY status printed out in the terminal for each loaded model." + "For the `--model-repository` argument, specify the same path as the `ens_model_path` that you specified previously in the `ensemble.export` method. This command will launch the server and load all the models to the server. Once all the models are loaded successfully, you should see READY status printed out in the terminal for each loaded model." ] }, { @@ -1244,7 +1231,8 @@ "interactions_merged_df = interactions_merged_df.sort_values('timestamp')\n", "batch = interactions_merged_df[-50:]\n", "sessions_to_use = batch.session_id.value_counts()\n", - "filtered_batch = batch[batch.session_id.isin(sessions_to_use[sessions_to_use.values>1].index.values)]" + "filtered_batch = batch[batch.session_id.isin(sessions_to_use[sessions_to_use.values>1].index.values)]\n", + "filtered_batch = filtered_batch.dropna()" ] }, { @@ -1305,18 +1293,22 @@ "metadata": {}, "outputs": [ { - "ename": "ValueError", - "evalue": "cannot convert NA to integer", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[26], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmerlin\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msystems\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtriton\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m send_triton_request\n\u001b[0;32m----> 2\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43msend_triton_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mworkflow\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_schema\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfiltered_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput_schema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumn_names\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(response)\n", - "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/merlin/systems/triton/utils.py:226\u001b[0m, in \u001b[0;36msend_triton_request\u001b[0;34m(schema, inputs, outputs_list, client, endpoint, request_id, triton_model)\u001b[0m\n\u001b[1;32m 224\u001b[0m triton_inputs \u001b[38;5;241m=\u001b[39m triton\u001b[38;5;241m.\u001b[39mconvert_table_to_triton_input(schema, inputs, grpcclient\u001b[38;5;241m.\u001b[39mInferInput)\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 226\u001b[0m triton_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mtriton\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconvert_df_to_triton_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrpcclient\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferInput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 228\u001b[0m outputs \u001b[38;5;241m=\u001b[39m [grpcclient\u001b[38;5;241m.\u001b[39mInferRequestedOutput(col) \u001b[38;5;28;01mfor\u001b[39;00m col \u001b[38;5;129;01min\u001b[39;00m outputs_list]\n\u001b[1;32m 230\u001b[0m response \u001b[38;5;241m=\u001b[39m client\u001b[38;5;241m.\u001b[39minfer(triton_model, triton_inputs, request_id\u001b[38;5;241m=\u001b[39mrequest_id, outputs\u001b[38;5;241m=\u001b[39moutputs)\n", - "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/merlin/systems/triton/__init__.py:88\u001b[0m, in \u001b[0;36mconvert_df_to_triton_input\u001b[0;34m(schema, batch, input_class, dtype)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconvert_df_to_triton_input\u001b[39m(schema, batch, input_class\u001b[38;5;241m=\u001b[39mgrpcclient\u001b[38;5;241m.\u001b[39mInferInput, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 69\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;124;03m Convert a dataframe to a set of Triton inputs\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;124;03m A list of Triton inputs of the requested input class\u001b[39;00m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 88\u001b[0m df_dict \u001b[38;5;241m=\u001b[39m \u001b[43m_convert_df_to_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 89\u001b[0m inputs \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 90\u001b[0m _convert_array_to_triton_input(col_name, col_values, input_class)\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m col_name, col_values \u001b[38;5;129;01min\u001b[39;00m df_dict\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 92\u001b[0m ]\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m inputs\n", - "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/merlin/systems/triton/__init__.py:183\u001b[0m, in \u001b[0;36m_convert_df_to_dict\u001b[0;34m(schema, batch, dtype)\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 182\u001b[0m values \u001b[38;5;241m=\u001b[39m col\u001b[38;5;241m.\u001b[39mvalues \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(col, pd\u001b[38;5;241m.\u001b[39mSeries) \u001b[38;5;28;01melse\u001b[39;00m col\u001b[38;5;241m.\u001b[39mvalues_host\n\u001b[0;32m--> 183\u001b[0m values \u001b[38;5;241m=\u001b[39m \u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcol_schema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_numpy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 184\u001b[0m df_dict[col_name] \u001b[38;5;241m=\u001b[39m values\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m df_dict\n", - "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/arrays/masked.py:471\u001b[0m, in \u001b[0;36mBaseMaskedArray.astype\u001b[0;34m(self, dtype, copy)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;66;03m# to_numpy will also raise, but we get somewhat nicer exception messages here\u001b[39;00m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer_dtype(dtype) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_hasna:\n\u001b[0;32m--> 471\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot convert NA to integer\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 472\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_bool_dtype(dtype) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_hasna:\n\u001b[1;32m 473\u001b[0m \u001b[38;5;66;03m# careful: astype_nansafe converts np.nan to True\u001b[39;00m\n\u001b[1;32m 474\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot convert float NaN to bool\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "\u001b[0;31mValueError\u001b[0m: cannot convert NA to integer" + "name": "stdout", + "output_type": "stream", + "text": [ + "{'item_id_scores': array([[6.753515 , 6.7321076, 6.550942 , 6.530278 , 6.524622 , 6.4839935,\n", + " 6.4497213, 6.2553434, 6.2440977, 6.0922346, 6.077484 , 6.051918 ,\n", + " 6.0381503, 5.9162364, 5.913348 , 5.9052362, 5.813846 , 5.7530656,\n", + " 5.6659293, 5.528434 ],\n", + " [6.2170486, 5.98991 , 5.9599447, 5.653755 , 5.6248474, 5.5384746,\n", + " 5.4608607, 5.401359 , 5.3743906, 5.370715 , 5.3688684, 5.338537 ,\n", + " 5.243487 , 5.2261944, 5.22056 , 5.2028203, 5.14725 , 5.146728 ,\n", + " 5.137348 , 5.136954 ]], dtype=float32), 'item_ids': array([[ 113, 109, 69, 302, 542, 534, 478, 260, 417,\n", + " 141, 199, 6102, 259, 86, 4453, 103, 482, 414,\n", + " 197, 5730],\n", + " [ 282, 8042, 6740, 31054, 10928, 7022, 3052, 10743, 5873,\n", + " 1017, 9413, 8298, 6967, 1922, 2707, 4432, 15535, 6164,\n", + " 11396, 9228]])}\n" ] } ], @@ -1372,7 +1364,7 @@ { "data": { "text/plain": [ - "(15, 20)" + "(2, 20)" ] }, "execution_count": 28,