diff --git a/element_deeplabcut/__init__.py b/element_deeplabcut/__init__.py index 80f05d0..7f2535c 100644 --- a/element_deeplabcut/__init__.py +++ b/element_deeplabcut/__init__.py @@ -18,4 +18,4 @@ "DLC_PROCESSED_DATA_DIR", dj.config["custom"].get("dlc_processed_data_dir", "") ) -db_prefix = dj.config["custom"].get("database.prefix", "") \ No newline at end of file +db_prefix = dj.config["custom"].get("database.prefix", "") diff --git a/notebooks/tutorial.ipynb b/notebooks/tutorial.ipynb index 825c12e..b63193c 100644 --- a/notebooks/tutorial.ipynb +++ b/notebooks/tutorial.ipynb @@ -120,9 +120,12 @@ "outputs": [], "source": [ "import os\n", - "if os.path.basename(os.getcwd())=='notebooks': os.chdir('..')\n", - "assert os.path.basename(os.getcwd())=='element-deeplabcut', (\"Please move to the \"\n", - " + \"element directory\")" + "\n", + "if os.path.basename(os.getcwd()) == \"notebooks\":\n", + " os.chdir(\"..\")\n", + "assert os.path.basename(os.getcwd()) == \"element-deeplabcut\", (\n", + " \"Please move to the \" + \"element directory\"\n", + ")" ] }, { @@ -201,7 +204,7 @@ } ], "source": [ - "from tutorial_pipeline import lab, subject, session, train, model " + "from tutorial_pipeline import lab, subject, session, train, model" ] }, { @@ -990,10 +993,10 @@ ], "source": [ "(\n", - " dj.Diagram(subject) \n", - " + dj.Diagram(lab) \n", - " + dj.Diagram(session) \n", - " + dj.Diagram(model) \n", + " dj.Diagram(subject)\n", + " + dj.Diagram(lab)\n", + " + dj.Diagram(session)\n", + " + dj.Diagram(model)\n", " + dj.Diagram(train)\n", ")" ] @@ -1274,7 +1277,9 @@ "metadata": {}, "outputs": [], "source": [ - "config_file_rel = \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\"" + "config_file_rel = (\n", + " \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/config.yaml\"\n", + ")" ] }, { @@ -1358,11 +1363,13 @@ } ], "source": [ - "model.Model.insert_new_model(model_name='from_top_tracking_model_test',\n", - " dlc_config=config_file_rel,\n", - " shuffle=1,\n", - " trainingsetindex=0,\n", - " model_description='Model in example data: from_top_tracking model')" + "model.Model.insert_new_model(\n", + " model_name=\"from_top_tracking_model_test\",\n", + " dlc_config=config_file_rel,\n", + " shuffle=1,\n", + " trainingsetindex=0,\n", + " model_description=\"Model in example data: from_top_tracking model\",\n", + ")" ] }, { @@ -1668,14 +1675,14 @@ "metadata": {}, "outputs": [], "source": [ - "#Definition of the dictionary named \"session_keys\"\n", + "# Definition of the dictionary named \"session_keys\"\n", "session_keys = [\n", " dict(subject=\"subject6\", session_datetime=\"2021-06-02 14:04:22\"),\n", " dict(subject=\"subject6\", session_datetime=\"2021-06-03 14:43:10\"),\n", "]\n", "\n", - "#Insert this dictionary in the Session table\n", - "session.Session.insert(session_keys, skip_duplicates=True)\n" + "# Insert this dictionary in the Session table\n", + "session.Session.insert(session_keys, skip_duplicates=True)" ] }, { @@ -1791,10 +1798,14 @@ "metadata": {}, "outputs": [], "source": [ - "recording_key = {'subject': 'subject6',\n", - " 'session_datetime': '2021-06-02 14:04:22',\n", - " 'recording_id': '1'}\n", - "model.VideoRecording.insert1({**recording_key, 'device': 'Camera1'}, skip_duplicates=True)" + "recording_key = {\n", + " \"subject\": \"subject6\",\n", + " \"session_datetime\": \"2021-06-02 14:04:22\",\n", + " \"recording_id\": \"1\",\n", + "}\n", + "model.VideoRecording.insert1(\n", + " {**recording_key, \"device\": \"Camera1\"}, skip_duplicates=True\n", + ")" ] }, { @@ -1810,12 +1821,14 @@ "metadata": {}, "outputs": [], "source": [ - "video_files = [\"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"]\n", + "video_files = [\n", + " \"./example_data/inbox/from_top_tracking-DataJoint-2023-10-11/videos/train1.mp4\"\n", + "]\n", "\n", - "model.VideoRecording.File.insert({\n", - " **recording_key, \n", - " 'file_id': v_idx, \n", - " 'file_path': Path(f)} for v_idx, f in enumerate(video_files))" + "model.VideoRecording.File.insert(\n", + " {**recording_key, \"file_id\": v_idx, \"file_path\": Path(f)}\n", + " for v_idx, f in enumerate(video_files)\n", + ")" ] }, { @@ -2054,7 +2067,7 @@ "metadata": {}, "outputs": [], "source": [ - "task_key = {**recording_key, 'model_name': 'from_top_tracking_model_test'}" + "task_key = {**recording_key, \"model_name\": \"from_top_tracking_model_test\"}" ] }, { @@ -2071,10 +2084,12 @@ "outputs": [], "source": [ "model.PoseEstimationTask.insert1(\n", - " {**task_key,\n", - " 'task_mode': 'load',\n", - " 'pose_estimation_output_dir': './example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters'\n", - " })" + " {\n", + " **task_key,\n", + " \"task_mode\": \"load\",\n", + " \"pose_estimation_output_dir\": \"./example_data/outbox/from_top_tracking-DataJoint-2023-10-11/videos/device_1_recording_1_model_from_top_tracking_100000_maxiters\",\n", + " }\n", + ")" ] }, { @@ -2471,7 +2486,11 @@ "metadata": {}, "outputs": [], "source": [ - "df = (model.PoseEstimation.BodyPartPosition & task_key).fetch(format='frame').reset_index()" + "df = (\n", + " (model.PoseEstimation.BodyPartPosition & task_key)\n", + " .fetch(format=\"frame\")\n", + " .reset_index()\n", + ")" ] }, { @@ -2836,7 +2855,7 @@ } ], "source": [ - "df = df.explode(['frame_index', 'x_pos', 'y_pos', 'likelihood']).reset_index()\n", + "df = df.explode([\"frame_index\", \"x_pos\", \"y_pos\", \"likelihood\"]).reset_index()\n", "df" ] }, @@ -2871,8 +2890,8 @@ "source": [ "import matplotlib.pyplot as plt\n", "\n", - "head_data = df[df['body_part'] == 'head']\n", - "tail_data = df[df['body_part'] == 'tailbase']" + "head_data = df[df[\"body_part\"] == \"head\"]\n", + "tail_data = df[df[\"body_part\"] == \"tailbase\"]" ] }, { @@ -2892,18 +2911,18 @@ } ], "source": [ - "fig, axs = plt.subplots(2,1, figsize=(12, 4))\n", + "fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n", "\n", - "axs[0].set_title('x position - Head pose estimation')\n", - "axs[0].plot(head_data['x_pos'], label='x_pos')\n", - "axs[0].set_xlabel('time (frames)')\n", - "axs[0].set_ylabel('pos (pixels)')\n", + "axs[0].set_title(\"x position - Head pose estimation\")\n", + "axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\")\n", + "axs[0].set_xlabel(\"time (frames)\")\n", + "axs[0].set_ylabel(\"pos (pixels)\")\n", "axs[0].legend()\n", "\n", - "axs[1].set_title('y position - Head pose estimation')\n", - "axs[1].plot(head_data['y_pos'], label='y_pos')\n", - "axs[1].set_xlabel('time (frames)')\n", - "axs[1].set_ylabel('pos (pixels)')\n", + "axs[1].set_title(\"y position - Head pose estimation\")\n", + "axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\")\n", + "axs[1].set_xlabel(\"time (frames)\")\n", + "axs[1].set_ylabel(\"pos (pixels)\")\n", "axs[1].legend()\n", "\n", "plt.tight_layout()\n", @@ -2927,17 +2946,17 @@ } ], "source": [ - "fig, axs = plt.subplots(2,1, figsize=(12, 4))\n", - "axs[0].set_title('x position - Tailbase pose estimation')\n", - "axs[0].plot(head_data['x_pos'], label='x_pos',color='orange')\n", - "axs[0].set_xlabel('time (frames)')\n", - "axs[0].set_ylabel('pos (pixels)')\n", + "fig, axs = plt.subplots(2, 1, figsize=(12, 4))\n", + "axs[0].set_title(\"x position - Tailbase pose estimation\")\n", + "axs[0].plot(head_data[\"x_pos\"], label=\"x_pos\", color=\"orange\")\n", + "axs[0].set_xlabel(\"time (frames)\")\n", + "axs[0].set_ylabel(\"pos (pixels)\")\n", "axs[0].legend()\n", "\n", - "axs[1].set_title('y position - Tailbase pose estimation')\n", - "axs[1].plot(head_data['y_pos'], label='y_pos',color='orange')\n", - "axs[1].set_xlabel('time (frames)')\n", - "axs[1].set_ylabel('pos (pixels)')\n", + "axs[1].set_title(\"y position - Tailbase pose estimation\")\n", + "axs[1].plot(head_data[\"y_pos\"], label=\"y_pos\", color=\"orange\")\n", + "axs[1].set_xlabel(\"time (frames)\")\n", + "axs[1].set_ylabel(\"pos (pixels)\")\n", "axs[1].legend()\n", "\n", "plt.tight_layout()\n", @@ -2968,18 +2987,18 @@ } ], "source": [ - "fig, axs = plt.subplots(2,1, figsize=(6,10))\n", + "fig, axs = plt.subplots(2, 1, figsize=(6, 10))\n", "\n", - "axs[0].set_title('Head pose estimation')\n", - "axs[0].plot(head_data['x_pos'], head_data['y_pos'],label='head',color='blue')\n", - "axs[0].set_xlabel('x position (pixels)')\n", - "axs[0].set_ylabel('y position (pixels)')\n", + "axs[0].set_title(\"Head pose estimation\")\n", + "axs[0].plot(head_data[\"x_pos\"], head_data[\"y_pos\"], label=\"head\", color=\"blue\")\n", + "axs[0].set_xlabel(\"x position (pixels)\")\n", + "axs[0].set_ylabel(\"y position (pixels)\")\n", "axs[0].legend()\n", "\n", - "axs[1].set_title('Tailbase pose estimation')\n", - "axs[1].plot(tail_data['x_pos'], tail_data['y_pos'], label='tailbase',color='orange')\n", - "axs[1].set_xlabel('x position (pixels)')\n", - "axs[1].set_ylabel('y position (pixels)')\n", + "axs[1].set_title(\"Tailbase pose estimation\")\n", + "axs[1].plot(tail_data[\"x_pos\"], tail_data[\"y_pos\"], label=\"tailbase\", color=\"orange\")\n", + "axs[1].set_xlabel(\"x position (pixels)\")\n", + "axs[1].set_ylabel(\"y position (pixels)\")\n", "axs[1].legend()\n", "\n", "plt.tight_layout()\n",