diff --git a/client/data/routines.py b/client/data/routines.py index 6950220..69ebe28 100644 --- a/client/data/routines.py +++ b/client/data/routines.py @@ -6,11 +6,14 @@ from typing import Any, NamedTuple, Optional from importlib import resources +import cv2 +import numpy as np from pydbml import PyDBML RESOURCES = resources.files("data.resources") DATABASE_DEFINITION = RESOURCES.joinpath("database.dbml") DATABASE_RESOURCE = RESOURCES.joinpath("database.db") +FACES_FOLDER = RESOURCES.joinpath("faces") class User(NamedTuple): @@ -175,6 +178,22 @@ def get_user_postures( return [Posture(*record) for record in result.fetchall()] +def register_faces(user_id: int, faces: list[np.ndarray]) -> None: + """Register faces for a user. + + Args: + user_id: The user to register faces for. + faces: List of face arrays in the format HxWxC where channels are RGB + """ + with resources.as_file(FACES_FOLDER) as faces_folder: + faces_folder.mkdir(exist_ok=True) + user_folder = faces_folder / str(user_id) + user_folder.mkdir() + for i, image in enumerate(faces): + image_path = user_folder / f"{i}.png" + cv2.imwrite(str(image_path), image) + + def get_schema_info() -> list[list[tuple[Any]]]: """Column information on all tables in database. diff --git a/scripts/register_faces.py b/scripts/register_faces.py new file mode 100644 index 0000000..441f0f6 --- /dev/null +++ b/scripts/register_faces.py @@ -0,0 +1,30 @@ +import logging + +import cv2 + +from data.routines import destroy_database, init_database, register_faces + +logger = logging.getLogger(__name__) + + +def main(): + logging.basicConfig(level=logging.DEBUG) + logger.debug("Destroying db") + destroy_database() + logger.debug("Init db") + init_database() + + video = cv2.VideoCapture(0) + faces = [] + for _ in range(5): + _, frame = video.read() + faces.append(frame) + input("Enter to continue: ") + + video.release() + logger.debug("Registering faces") + register_faces(1, faces) + + +if __name__ == "__main__": + main()