Skip to content
This repository has been archived by the owner on Dec 8, 2024. It is now read-only.

Commit

Permalink
Merge pull request #54 from LimaoC/mitch-register-faces
Browse files Browse the repository at this point in the history
Implemented register_faces
  • Loading branch information
MitchellJC authored Sep 13, 2024
2 parents 9de1b4d + 854d65d commit 4088bc8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
19 changes: 19 additions & 0 deletions client/data/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from typing import Any, NamedTuple, Optional
from importlib import resources

import cv2
import numpy as np
from pydbml import PyDBML

RESOURCES = resources.files("data.resources")
DATABASE_DEFINITION = RESOURCES.joinpath("database.dbml")
DATABASE_RESOURCE = RESOURCES.joinpath("database.db")
FACES_FOLDER = RESOURCES.joinpath("faces")


class User(NamedTuple):
Expand Down Expand Up @@ -175,6 +178,22 @@ def get_user_postures(
return [Posture(*record) for record in result.fetchall()]


def register_faces(user_id: int, faces: list[np.ndarray]) -> None:
"""Register faces for a user.
Args:
user_id: The user to register faces for.
faces: List of face arrays in the format HxWxC where channels are RGB
"""
with resources.as_file(FACES_FOLDER) as faces_folder:
faces_folder.mkdir(exist_ok=True)
user_folder = faces_folder / str(user_id)
user_folder.mkdir()
for i, image in enumerate(faces):
image_path = user_folder / f"{i}.png"
cv2.imwrite(str(image_path), image)


def get_schema_info() -> list[list[tuple[Any]]]:
"""Column information on all tables in database.
Expand Down
30 changes: 30 additions & 0 deletions scripts/register_faces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import logging

import cv2

from data.routines import destroy_database, init_database, register_faces

logger = logging.getLogger(__name__)


def main():
logging.basicConfig(level=logging.DEBUG)
logger.debug("Destroying db")
destroy_database()
logger.debug("Init db")
init_database()

video = cv2.VideoCapture(0)
faces = []
for _ in range(5):
_, frame = video.read()
faces.append(frame)
input("Enter to continue: ")

video.release()
logger.debug("Registering faces")
register_faces(1, faces)


if __name__ == "__main__":
main()

0 comments on commit 4088bc8

Please sign in to comment.