diff --git a/child_lab_framework/demo_sequential.py b/child_lab_framework/demo_sequential.py index 3f855c7..9646308 100644 --- a/child_lab_framework/demo_sequential.py +++ b/child_lab_framework/demo_sequential.py @@ -5,7 +5,7 @@ from .core.video import Format, Perspective, Reader, Writer from .logging import Logger -from .task import depth, face, gaze, pose +from .task import depth, face, gaze, pose, emotion from .task.camera import transformation from .task.visualization import Visualizer @@ -62,6 +62,9 @@ def main() -> None: threshold=0.1, ) + emotions_estimator_left = emotion.Estimator(executor) + emotions_estimator_right = emotion.Estimator(executor) + window_left_gaze_estimator = gaze.Estimator( executor, input=window_left_reader.properties, @@ -106,6 +109,8 @@ def main() -> None: output_format=Format.MP4, ) + print('Starting sequential processing') + while True: ceiling_frames = ceiling_reader.read_batch() if ceiling_frames is None: @@ -205,11 +210,19 @@ def main() -> None: else None ) + window_left_emotions = emotions_estimator_left.predict_batch( + window_left_frames, window_left_faces + ) + window_right_emotions = emotions_estimator_right.predict_batch( + window_right_frames, window_right_faces + ) + ceiling_annotated_frames = visualizer.annotate_batch( ceiling_frames, ceiling_poses, None, ceiling_gazes, + None, ) window_left_annotated_frames = visualizer.annotate_batch( @@ -217,6 +230,7 @@ def main() -> None: window_left_poses, window_left_faces, None, + window_left_emotions, ) window_right_annotated_frames = visualizer.annotate_batch( @@ -224,6 +238,7 @@ def main() -> None: window_right_poses, window_right_faces, None, + window_right_emotions, ) ceiling_writer.write_batch(ceiling_annotated_frames) diff --git a/child_lab_framework/task/emotion/__init__.py b/child_lab_framework/task/emotion/__init__.py new file mode 100644 index 0000000..3c8fd40 --- /dev/null +++ b/child_lab_framework/task/emotion/__init__.py @@ -0,0 +1,3 @@ +from .emotion import Estimator, Result + +__all__ = ['Estimator', 'Result'] diff --git a/child_lab_framework/task/emotion/emotion.py b/child_lab_framework/task/emotion/emotion.py new file mode 100644 index 0000000..6391b08 --- /dev/null +++ b/child_lab_framework/task/emotion/emotion.py @@ -0,0 +1,109 @@ +import asyncio +from deepface import DeepFace +from concurrent.futures import ThreadPoolExecutor +from itertools import repeat, starmap + +from ...task import face +from ...core.sequence import imputed_with_reference_inplace +from ...core.video import Frame +from ...typing.stream import Fiber +from ...typing.array import FloatArray2 + +type Input = tuple[ + list[Frame | None] | None, + list[face.Result | None] | None, +] + + +class Result: + emotions: list[float] + boxes: list[FloatArray2] + + def __init__(self, emotions: list[float], boxes: list[FloatArray2]) -> None: + self.emotions = emotions + self.boxes = boxes + + +class Estimator: + executor: ThreadPoolExecutor + + def __init__(self, executor: ThreadPoolExecutor) -> None: + self.executor = executor + + def __predict(self, frame: Frame, faces: face.Result) -> Result: + face_emotions = [] + boxes = [] + frame_height, frame_width, _ = frame.shape + for face_box in faces.boxes: + x_min, y_min, x_max, y_max = face_box + x_min = max(x_min - 50, 0) + x_max = min(x_max + 50, frame_width) + y_min = max(y_min - 50, 0) + y_max = min(y_max + 50, frame_height) + cropped_frame = frame[y_min:y_max, x_min:x_max] + analysis = DeepFace.analyze( + cropped_frame, actions=['emotion'], enforce_detection=False + ) + emotion = self.__score(analysis[0]) + face_emotions.append(emotion) + boxes.append(face_box) + + return Result(face_emotions, boxes) + + def __predict_safe(self, frame: Frame, faces: face.Result | None) -> Result: + if faces is None: + return Result([], []) + return self.__predict(frame, faces) + + def predict_batch( + self, + frames: list[Frame], + faces: list[face.Result | None], + ) -> list[Result] | None: + return imputed_with_reference_inplace( + list(starmap(self.__predict, zip(frames, faces))) + ) + + async def stream( + self, + ) -> Fiber[list[Input] | None, list[Result | None] | None]: + loop = asyncio.get_running_loop() + executor = self.executor + + results: list[Result | None] | None = None + + while True: + match (yield results): + case ( + list(frames), + faces, + ): + results = await loop.run_in_executor( + executor, + lambda: list( + starmap( + self.__predict, + zip(frames, faces or repeat(None)), + ) + ), + ) + + case _: + results = None + + def __score(emotions: list[dict[str, float]]) -> list[float]: + # Most of the time, "angry" and "fear" are similar to "neutral" in the reality + scores = { + 'angry': -0.05, + 'disgust': 0, + 'fear': -0.07, + 'happy': 1, + 'sad': -1, + 'surprise': 0, + 'neutral': 0, + } + val = 0 + for emotion, score in scores.items(): + val += emotions['emotion'][emotion] * score + + return val \ No newline at end of file diff --git a/child_lab_framework/task/visualization/visualization.py b/child_lab_framework/task/visualization/visualization.py index 094090e..e9c6874 100644 --- a/child_lab_framework/task/visualization/visualization.py +++ b/child_lab_framework/task/visualization/visualization.py @@ -3,12 +3,13 @@ from itertools import repeat, starmap import cv2 +import cv2.text import numpy as np from ...core.video import Frame, Properties from ...typing.array import FloatArray1, FloatArray2, IntArray1 from ...typing.stream import Fiber -from .. import face, pose +from .. import face, pose, emotion from ..gaze import ceiling_projection from ..pose.keypoint import YOLO_SKELETON @@ -123,12 +124,27 @@ def __draw_face_box(self, frame: Frame, result: face.Result) -> Frame: return frame + def __draw_emotions_text(self, frame: Frame, result: emotion.Result) -> Frame: + color = self.FACE_BOUNDING_BOX_COLOR + for value, box in zip(result.emotions, result.boxes): + cv2.putText( + frame, + str(value), + [box[0], box[3]], + cv2.FONT_HERSHEY_SIMPLEX, + 0.9, + color, + 2, + ) + return frame + def __annotate_safe( self, frame: Frame, poses: pose.Result | None, faces: face.Result | None, gazes: ceiling_projection.Result | None, + emotions: emotion.Result | None, ) -> Frame: out = frame.copy() out.flags.writeable = True @@ -143,6 +159,9 @@ def __annotate_safe( if gazes is not None: out = self.__draw_gaze_estimation(out, gazes) + if emotions is not None: + out = self.__draw_emotions_text(out, emotions) + return out def annotate_batch( @@ -151,6 +170,7 @@ def annotate_batch( poses: list[pose.Result] | None, faces: list[face.Result] | None, gazes: list[ceiling_projection.Result] | None, + emotions: list[emotion.Result] | None, ) -> list[Frame]: return list( starmap( @@ -160,6 +180,7 @@ def annotate_batch( poses or repeat(None), faces or repeat(None), gazes or repeat(None), + emotions or repeat(None), ), ) ) diff --git a/poetry.lock b/poetry.lock index ad60eaa..f9155e4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -63,6 +63,38 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "beautifulsoup4" +version = "4.12.3" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, + {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, +] + +[package.dependencies] +soupsieve = ">1.2" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + +[[package]] +name = "blinker" +version = "1.8.2" +description = "Fast, simple object-to-object and broadcast signaling" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blinker-1.8.2-py3-none-any.whl", hash = "sha256:1779309f71bf239144b9399d06ae925637cf6634cf6bd131104184531bf67c01"}, + {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"}, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -267,6 +299,20 @@ files = [ {file = "charset_normalizer-3.4.0.tar.gz", hash = "sha256:223217c3d4f82c3ac5e29032b3f1c2eb0fb591b72161f86d93f5719079dae93e"}, ] +[[package]] +name = "click" +version = "8.1.7" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"}, + {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.6" @@ -377,6 +423,34 @@ files = [ docs = ["ipython", "matplotlib", "numpydoc", "sphinx"] tests = ["pytest", "pytest-cov", "pytest-xdist"] +[[package]] +name = "deepface" +version = "0.0.93" +description = "A Lightweight Face Recognition and Facial Attribute Analysis Framework (Age, Gender, Emotion, Race) for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "deepface-0.0.93-py3-none-any.whl", hash = "sha256:27043e1aa5df05a060bcfe1743409075c66f6f1de86a592ba0cdac79ac9e7987"}, + {file = "deepface-0.0.93.tar.gz", hash = "sha256:7f5fc6306a3a07ee6c529b03571e64fe53d9f259e1d4091f5e28386264962b92"}, +] + +[package.dependencies] +fire = ">=0.4.0" +Flask = ">=1.1.2" +flask-cors = ">=4.0.1" +gdown = ">=3.10.1" +gunicorn = ">=20.1.0" +keras = ">=2.2.0" +mtcnn = ">=0.1.0" +numpy = ">=1.14.0" +opencv-python = ">=4.5.5.64" +pandas = ">=0.23.4" +Pillow = ">=5.2.0" +requests = ">=2.27.1" +retina-face = ">=0.0.1" +tensorflow = ">=1.9.0" +tqdm = ">=4.30.0" + [[package]] name = "depth-pro" version = "0.1.0" @@ -450,6 +524,55 @@ docs = ["furo (>=2024.8.6)", "sphinx (>=8.0.2)", "sphinx-autodoc-typehints (>=2. testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "diff-cover (>=9.2)", "pytest (>=8.3.3)", "pytest-asyncio (>=0.24)", "pytest-cov (>=5)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.26.4)"] typing = ["typing-extensions (>=4.12.2)"] +[[package]] +name = "fire" +version = "0.7.0" +description = "A library for automatically generating command line interfaces." +optional = false +python-versions = "*" +files = [ + {file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"}, +] + +[package.dependencies] +termcolor = "*" + +[[package]] +name = "flask" +version = "3.0.3" +description = "A simple framework for building complex web applications." +optional = false +python-versions = ">=3.8" +files = [ + {file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"}, + {file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"}, +] + +[package.dependencies] +blinker = ">=1.6.2" +click = ">=8.1.3" +itsdangerous = ">=2.1.2" +Jinja2 = ">=3.1.2" +Werkzeug = ">=3.0.0" + +[package.extras] +async = ["asgiref (>=3.2)"] +dotenv = ["python-dotenv"] + +[[package]] +name = "flask-cors" +version = "5.0.0" +description = "A Flask extension adding a decorator for CORS support" +optional = false +python-versions = "*" +files = [ + {file = "Flask_Cors-5.0.0-py2.py3-none-any.whl", hash = "sha256:b9e307d082a9261c100d8fb0ba909eec6a228ed1b60a8315fd85f783d61910bc"}, + {file = "flask_cors-5.0.0.tar.gz", hash = "sha256:5aadb4b950c4e93745034594d9f3ea6591f734bb3662e16e255ffbf5e89c88ef"}, +] + +[package.dependencies] +Flask = ">=0.9" + [[package]] name = "flatbuffers" version = "24.3.25" @@ -582,6 +705,26 @@ files = [ {file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"}, ] +[[package]] +name = "gdown" +version = "5.2.0" +description = "Google Drive Public File/Folder Downloader" +optional = false +python-versions = ">=3.8" +files = [ + {file = "gdown-5.2.0-py3-none-any.whl", hash = "sha256:33083832d82b1101bdd0e9df3edd0fbc0e1c5f14c9d8c38d2a35bf1683b526d6"}, + {file = "gdown-5.2.0.tar.gz", hash = "sha256:2145165062d85520a3cd98b356c9ed522c5e7984d408535409fd46f94defc787"}, +] + +[package.dependencies] +beautifulsoup4 = "*" +filelock = "*" +requests = {version = "*", extras = ["socks"]} +tqdm = "*" + +[package.extras] +test = ["build", "mypy", "pytest", "pytest-xdist", "ruff", "twine", "types-requests", "types-setuptools"] + [[package]] name = "google-pasta" version = "0.2.0" @@ -664,6 +807,27 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.67.0)"] +[[package]] +name = "gunicorn" +version = "23.0.0" +description = "WSGI HTTP Server for UNIX" +optional = false +python-versions = ">=3.7" +files = [ + {file = "gunicorn-23.0.0-py3-none-any.whl", hash = "sha256:ec400d38950de4dfd418cff8328b2c8faed0edb0d517d3394e457c317908ca4d"}, + {file = "gunicorn-23.0.0.tar.gz", hash = "sha256:f014447a0101dc57e294f6c18ca6b40227a4c90e9bdb586042628030cba004ec"}, +] + +[package.dependencies] +packaging = "*" + +[package.extras] +eventlet = ["eventlet (>=0.24.1,!=0.36.0)"] +gevent = ["gevent (>=1.4.0)"] +setproctitle = ["setproctitle"] +testing = ["coverage", "eventlet", "gevent", "pytest", "pytest-cov"] +tornado = ["tornado (>=0.2)"] + [[package]] name = "h5py" version = "3.12.1" @@ -767,6 +931,17 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "itsdangerous" +version = "2.2.0" +description = "Safely pass data to untrusted environments and back." +optional = false +python-versions = ">=3.8" +files = [ + {file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"}, + {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, +] + [[package]] name = "jax" version = "0.4.35" @@ -2138,6 +2313,18 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] +[[package]] +name = "pysocks" +version = "1.7.1" +description = "A Python SOCKS client module. See https://github.com/Anorov/PySocks for more information." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "PySocks-1.7.1-py27-none-any.whl", hash = "sha256:08e69f092cc6dbe92a0fdd16eeb9b9ffbc13cadfe5ca4c7bd92ffb078b293299"}, + {file = "PySocks-1.7.1-py3-none-any.whl", hash = "sha256:2725bd0a9925919b9b51739eea5f9e2bae91e83288108a9ad338b2e3a4435ee5"}, + {file = "PySocks-1.7.1.tar.gz", hash = "sha256:3f8804571ebe159c380ac6de37643bb4685970655d3bba243530d6558b799aa0"}, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -2267,12 +2454,31 @@ files = [ certifi = ">=2017.4.17" charset-normalizer = ">=2,<4" idna = ">=2.5,<4" +PySocks = {version = ">=1.5.6,<1.5.7 || >1.5.7", optional = true, markers = "extra == \"socks\""} urllib3 = ">=1.21.1,<3" [package.extras] socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "retina-face" +version = "0.0.17" +description = "RetinaFace: Deep Face Detection Framework in TensorFlow for Python" +optional = false +python-versions = ">=3.5.5" +files = [ + {file = "retina-face-0.0.17.tar.gz", hash = "sha256:7532b136ed01fe9a8cba8dfbc5a046dd6fb1214b1a83e57f3210bd145a91cd73"}, + {file = "retina_face-0.0.17-py3-none-any.whl", hash = "sha256:b43fdac4078678b9d8bc45b88a7090f05d81c44e1e10710e6c16d703bb7add41"}, +] + +[package.dependencies] +gdown = ">=3.10.1" +numpy = ">=1.14.0" +opencv-python = ">=3.4.4" +Pillow = ">=5.2.0" +tensorflow = ">=1.9.0" + [[package]] name = "rich" version = "13.9.3" @@ -2572,6 +2778,17 @@ CFFI = ">=1.0" [package.extras] numpy = ["NumPy"] +[[package]] +name = "soupsieve" +version = "2.6" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.8" +files = [ + {file = "soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9"}, + {file = "soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb"}, +] + [[package]] name = "sympy" version = "1.13.1" @@ -3135,4 +3352,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "502cd6985714a46f2941ca82be24b94ef7b9ec7df3d3fa0f0b0d038385b65701" +content-hash = "746631e1ebed35c316b6840e25da0c1ce48b43da19fa63361392fccc7ae5e56b" diff --git a/pyproject.toml b/pyproject.toml index 8dcf456..3dfd891 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ ml-dtypes = "0.4.0" # installi tensorflow = "^2.17.0" mini-face = ">=0.1.0" depth-pro = { git = "https://github.com/child-lab-uj/depth-pro.git" } +deepface = "^0.0.93" [tool.poetry.group.dev.dependencies] poethepoet = "^0.26.1"