diff --git a/.coveragerc b/.coveragerc index d14242d..52a7b03 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,5 +4,4 @@ show_missing=true [run] omit = cameratokeyboard/app/ui.py - ci_train_and_upload.py */__init__.py \ No newline at end of file diff --git a/.gitignore b/.gitignore index 7fbe04e..6d15ecb 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ .vscode yolov8n.pt .coverage -/dist \ No newline at end of file +/dist +/cameratokeyboard/model.pt \ No newline at end of file diff --git a/c2k.py b/c2k.py index 8a0983b..1480563 100644 --- a/c2k.py +++ b/c2k.py @@ -1,3 +1,4 @@ +# pylint: disable=missing-function-docstring import asyncio import sys @@ -5,13 +6,15 @@ from cameratokeyboard.config import Config from cameratokeyboard.args import parse_args from cameratokeyboard.app import App +from cameratokeyboard.model.model_downloader import ModelDownloader args = parse_args(sys.argv[1:]) async def async_main(): - app = App(Config.from_args(args)) - await app.run() + config = Config.from_args(args) + ModelDownloader(config).run() + await App(config).run() def main(): diff --git a/cameratokeyboard/app/detector.py b/cameratokeyboard/app/detector.py index e51b634..98969f3 100644 --- a/cameratokeyboard/app/detector.py +++ b/cameratokeyboard/app/detector.py @@ -1,12 +1,10 @@ import logging -import os import ultralytics from cameratokeyboard.config import Config from cameratokeyboard.core.detected_frame import DetectedFrame - -MODEL_PATH = os.path.join(os.path.dirname(__file__), "..", "model.pt") +from cameratokeyboard.model.model_downloader import ModelDownloader class Detector: @@ -19,9 +17,10 @@ class Detector: def __init__(self, config: Config) -> None: logging.getLogger("ultralytics").setLevel(logging.ERROR) + model_path = ModelDownloader(config).local_path_to_latest_model self._config = config - self._model = ultralytics.YOLO(MODEL_PATH) + self._model = ultralytics.YOLO(model_path) self._device = config.processing_device self._iou = config.iou 
self._detected_frame = None diff --git a/cameratokeyboard/config.py b/cameratokeyboard/config.py index 7cdf69e..73f7cbf 100644 --- a/cameratokeyboard/config.py +++ b/cameratokeyboard/config.py @@ -1,11 +1,15 @@ +# pylint: disable=missing-function-docstring,too-many-instance-attributes from dataclasses import dataclass import os +import platformdirs + EXCLUDED_KEYS = ["command"] +PRIVATE_KEYS = ["model_path"] @dataclass -class Config: # pylint: disable=too-many-instance-attributes +class Config: """ Application wide configuration. """ @@ -20,8 +24,6 @@ class Config: # pylint: disable=too-many-instance-attributes image_extension: str = "jpg" iou: float = 0.5 - model_path: str = os.path.join("cameratokeyboard", "model.pt") - resolution: tuple = (1280, 720) app_fps: int = 30 video_input_device: int = 0 @@ -35,9 +37,41 @@ class Config: # pylint: disable=too-many-instance-attributes keyboard_layout: str = "qwerty" repeating_keys_delay: float = 0.5 + remote_models_bucket_region: str = "eu-west-2" + remote_models_bucket_name: str = "c2k" + remote_models_prefix: str = "models/" + models_dir: str = os.path.join(platformdirs.user_data_dir(), "c2k", "models") + _model_path: str = None + + @property + def model_path(self) -> str: + if self._model_path: + return self._model_path + + models = [x for x in os.listdir(self.models_dir) if x.endswith(".pt")] + + def sort_key(model_name): + return os.path.getctime(os.path.join(self.models_dir, model_name)) + + models = sorted(models, key=sort_key, reverse=True) + try: + return os.path.join(self.models_dir, models[0]) + except IndexError: + return None + + @model_path.setter + def model_path(self, value): + self._model_path = value + @classmethod def from_args(cls, args: dict) -> "Config": """ Builds a Config object from the given arguments. 
""" - return cls(**{k: v for k, v in args.items() if k not in EXCLUDED_KEYS}) + processed_args = {k: v for k, v in args.items() if k not in EXCLUDED_KEYS} + for private_key in PRIVATE_KEYS: + if private_key in processed_args: + processed_args[f"_{private_key}"] = processed_args[private_key] + del processed_args[private_key] + + return cls(**processed_args) diff --git a/cameratokeyboard/model.pt b/cameratokeyboard/model.pt deleted file mode 100644 index 169ae3f..0000000 Binary files a/cameratokeyboard/model.pt and /dev/null differ diff --git a/cameratokeyboard/model/augmenter.py b/cameratokeyboard/model/augmenter.py index 36f6b33..24ba751 100644 --- a/cameratokeyboard/model/augmenter.py +++ b/cameratokeyboard/model/augmenter.py @@ -1,5 +1,4 @@ -# DEPRECATED: This and all related files will be removed once we move to the new data pipeline -# pylint: skip-file +# pylint: disable=too-many-locals from ast import literal_eval import os @@ -24,16 +23,37 @@ class ImageAugmenterStrategy: + """ + A class representing an image augmentation strategy. + + Attributes: + augmentation_strategies (List[List[ImageAugmenter]]): A list of lists of + ImageAugmenter objects representing the augmentation strategies to be applied. + images_path (str): The path to the directory containing the input images. + labels_path (str): The path to the directory containing the label files. + files (List[str]): A list of filenames in the images_path directory. + """ + def __init__(self, config: Config) -> None: + """ + Initializes an ImageAugmenterStrategy object. + + Args: + config (Config): The configuration object containing the necessary parameters. 
+ + """ self.augmentation_strategies = [] self._resolve_strategies() train_path = config.split_paths[0] self.images_path = os.path.join(config.dataset_path, "images", train_path) self.labels_path = os.path.join(config.dataset_path, "labels", train_path) - self.files = [f for f in os.listdir(self.images_path)] + self.files = list(os.listdir(self.images_path)) def run(self) -> None: + """ + Runs the strategy + """ for strategy in self.augmentation_strategies: self._run_strategy(strategy) diff --git a/cameratokeyboard/model/augmenters.py b/cameratokeyboard/model/augmenters.py index a45161d..b856f1f 100644 --- a/cameratokeyboard/model/augmenters.py +++ b/cameratokeyboard/model/augmenters.py @@ -1,6 +1,3 @@ -# DEPRECATED: This and all related files will be removed once we move to the new data pipeline -# pylint: skip-file - from abc import ABC, abstractmethod from typing import Tuple, Union @@ -10,11 +7,25 @@ class ImageAugmenter(ABC): + """ + Abstract base class for image augmenters. + """ + @abstractmethod def __init__(self, *args, **kwargs): pass def apply(self, image: np.ndarray, bounding_boxes: str) -> Tuple[np.ndarray, str]: + """ + Apply image augmentation to the input image and bounding boxes. + + Args: + image (np.ndarray): The input image. + bounding_boxes (str): The bounding boxes in string format. + + Returns: + Tuple[np.ndarray, str]: The augmented image and the serialized bounding boxes. + """ bbs = self.parse_bounding_boxes(bounding_boxes, image.shape) aug_image, aug_bounding_boxes = self.augmenter(image=image, bounding_boxes=bbs) @@ -23,6 +34,16 @@ def apply(self, image: np.ndarray, bounding_boxes: str) -> Tuple[np.ndarray, str def parse_bounding_boxes( self, bounding_boxes: str, image_shape: Tuple[int, int] ) -> BoundingBoxesOnImage: + """ + Parse the bounding boxes from the string format to BoundingBoxesOnImage object. + + Args: + bounding_boxes (str): The bounding boxes in string format. + image_shape (Tuple[int, int]): The shape of the input image. 
+ + Returns: + BoundingBoxesOnImage: The parsed bounding boxes. + """ if not bounding_boxes: return None @@ -45,7 +66,17 @@ def parse_bounding_boxes( return BoundingBoxesOnImage(parsed_bounding_boxes, shape=image_shape) - def serialize_bounding_boxes(bounding_boxes: BoundingBoxesOnImage) -> str: + @classmethod + def serialize_bounding_boxes(cls, bounding_boxes: BoundingBoxesOnImage) -> str: + """ + Serialize the bounding boxes from BoundingBoxesOnImage object to string format. + + Args: + bounding_boxes (BoundingBoxesOnImage): The bounding boxes. + + Returns: + str: The serialized bounding boxes. + """ if not bounding_boxes: return None @@ -65,6 +96,16 @@ def serialize_bounding_boxes(bounding_boxes: BoundingBoxesOnImage) -> str: class ScaleAugmenter(ImageAugmenter): + """ + A class representing a scale augmenter. + + This augmenter applies scaling transformations to images. + + Args: + min_scale (float): The minimum scale factor to apply. + max_scale (float): The maximum scale factor to apply. + """ + def __init__(self, min_scale: float, max_scale: float): self.min_scale = min_scale self.max_scale = max_scale @@ -77,6 +118,16 @@ def __repr__(self) -> str: class RotationAugmenter(ImageAugmenter): + """ + A class representing a rotation augmenter for images. + + Attributes: + min_angle (float): The minimum angle of rotation. + max_angle (float): The maximum angle of rotation. + augmenter (iaa.Affine): The image augmentation object. + + """ + def __init__(self, min_angle: float, max_angle: float): self.min_angle = min_angle self.max_angle = max_angle @@ -87,6 +138,16 @@ def __repr__(self) -> str: class VerticalFlipAugmenter(ImageAugmenter): + """ + A class representing a vertical flip augmenter. + + This augmenter flips images vertically. + + Attributes: + augmenter (imgaug.augmenters.Flipud): The vertical flip augmenter. 
+ + """ + def __init__(self): self.augmenter = iaa.Flipud(True) @@ -95,6 +156,10 @@ def __repr__(self) -> str: class HorizontalFlipAugmenter(ImageAugmenter): + """ + Augmenter that performs horizontal flipping on images. + """ + def __init__(self): self.augmenter = iaa.Fliplr(True) @@ -103,13 +168,33 @@ def __repr__(self) -> str: class BlurAugmenter(ImageAugmenter): + """ + A class representing a blur augmenter that applies Gaussian blur to an image. + + Args: + min_sigma (float): The minimum standard deviation for the Gaussian blur. + max_sigma (float): The maximum standard deviation for the Gaussian blur. + """ + def __init__(self, min_sigma: float, max_sigma: float): self.min_sigma = min_sigma self.max_sigma = max_sigma + self.augmenter = None def apply( self, image: np.ndarray, bounding_boxes: str ) -> Tuple[Union[np.ndarray, str]]: + """ + Applies Gaussian blur to the input image. + + Args: + image (np.ndarray): The input image to be augmented. + bounding_boxes (str): The bounding boxes associated with the image. + + Returns: + Tuple[Union[np.ndarray, str]]: A tuple containing the augmented image and the + bounding boxes. + """ sigma = np.random.uniform(self.min_sigma, self.max_sigma) self.augmenter = iaa.GaussianBlur(sigma=sigma) return super().apply(image, bounding_boxes) @@ -119,6 +204,15 @@ def __repr__(self) -> str: class ShearAugmenter(ImageAugmenter): + """ + A class representing a shear augmenter for image data. + + Attributes: + min_shear (float): The minimum shear value to apply. + max_shear (float): The maximum shear value to apply. + augmenter (iaa.ShearX): The shear augmentation object. + """ + def __init__(self, min_shear: float, max_shear: float): self.min_shear = min_shear self.max_shear = max_shear @@ -129,6 +223,16 @@ def __repr__(self) -> str: class PerspectiveAugmenter(ImageAugmenter): + """ + A class representing a perspective augmenter for image data. + + This augmenter applies perspective transformations to images. 
+ + Args: + min_scale (float): The minimum scale factor for the perspective transformation. + max_scale (float): The maximum scale factor for the perspective transformation. + """ + def __init__(self, min_scale: float, max_scale: float): self.min_scale = min_scale self.max_scale = max_scale diff --git a/cameratokeyboard/model/model_downloader.py b/cameratokeyboard/model/model_downloader.py new file mode 100644 index 0000000..d38e0c4 --- /dev/null +++ b/cameratokeyboard/model/model_downloader.py @@ -0,0 +1,93 @@ +import os + +import requests + +from cameratokeyboard.config import Config +from cameratokeyboard.logger import get_logger +from cameratokeyboard.utils.s3_response import S3Response + +logger = get_logger() +REQUESTS_TIMEOUT = 10.0 + + +class ModelDownloader: + """ + Checks for model updates and downloads the latest version if available. + """ + + def __init__(self, config: Config): + self.config = config + + def run(self) -> None: + """ + Runs the checker and downloader + """ + logger.info("Checking for new models...") + + os.makedirs(self.config.models_dir, exist_ok=True) + + latest_version_filename = self._get_latest_version_filename() + + if not latest_version_filename: + logger.warning("No remote models found!") + return + + if self._is_latest_version_downloaded(latest_version_filename): + logger.info("Latest model already downloaded.") + return + + logger.info("Found a new model version: %s", latest_version_filename) + + downloaded_path = self._download_model(latest_version_filename) + + if downloaded_path: + logger.info("Model downloaded to %s", downloaded_path) + + @property + def local_path_to_latest_model(self): + """ + Returns the local path to the latest model + """ + latest_verison_filename = self._get_latest_version_filename() + return os.path.join(self.config.models_dir, latest_verison_filename) + + @property + def _bucket_url(self): + bucket = self.config.remote_models_bucket_name + region = self.config.remote_models_bucket_region + return 
f"https://{bucket}.s3.{region}.amazonaws.com" + + def _get_latest_version_filename(self) -> str: + content = requests.get( + self._bucket_url, timeout=REQUESTS_TIMEOUT + ).content.decode() + parsed_content = [ + x + for x in S3Response(content).get_objects() + if x.key.startswith(self.config.remote_models_prefix) + ] + sorted_content = sorted(parsed_content, key=lambda x: x.last_modified) + return sorted_content[-1].key.replace(self.config.remote_models_prefix, "") + + def _is_latest_version_downloaded(self, latest_version: str) -> bool: + return latest_version in os.listdir(self.config.models_dir) + + def _download_model(self, version: str) -> str: + url = f"{self._bucket_url}/{self.config.remote_models_prefix}{version}" + model_path = os.path.join(self.config.models_dir, version) + + response = requests.get(url, timeout=REQUESTS_TIMEOUT, stream=True) + + if response.status_code == 200: + with open(model_path, "wb") as f: + f.write(response.content) + + return model_path + + logger.error( + "Could not download model %s: [%d] %s", + version, + response.status_code, + response.content, + ) + return None diff --git a/cameratokeyboard/model/partitioner.py b/cameratokeyboard/model/partitioner.py index d1becd8..ad409fa 100644 --- a/cameratokeyboard/model/partitioner.py +++ b/cameratokeyboard/model/partitioner.py @@ -1,5 +1,3 @@ -# DEPRECATED: This and all related files will be removed once we move to the new data pipeline - import os from random import shuffle import shutil @@ -50,8 +48,8 @@ def partition(self) -> None: def _read_files_list(self) -> None: LOGGER.info("Reading files list.") - self._files_list = list( - set(f.split(".")[0] for f in os.listdir(self._raw_dataset_path)) + self._files_list = sorted( + list(set(f.split(".")[0] for f in os.listdir(self._raw_dataset_path))) ) shuffle(self._files_list) diff --git a/cameratokeyboard/model/train.py b/cameratokeyboard/model/train.py index 369468e..28d14aa 100644 --- a/cameratokeyboard/model/train.py +++ 
b/cameratokeyboard/model/train.py @@ -1,6 +1,4 @@ -# DEPRECATED: This and all related files will be removed once we move to the new data pipeline -# pylint: skip-file - +import hashlib import os import shutil @@ -12,9 +10,24 @@ class Trainer: - def __init__(self, config: Config) -> None: + """ + The Trainer class is responsible for training the model using the provided configuration. + + Args: + config (Config): The configuration object containing the necessary parameters for training. + target_path (str): Copies the trained model to this location when done. If None, + defaults to the `models_dir` config value. + + """ + + def __init__(self, config: Config, target_path: str = None) -> None: self.config = config + if target_path is None: + self._target_path = config.models_dir + else: + self._target_path = target_path + self.raw_dataset_path = config.raw_dataset_path self.dataset_path = config.dataset_path self.split_paths = config.split_paths @@ -22,31 +35,33 @@ def __init__(self, config: Config) -> None: settings.update({"datasets_dir": os.path.join(os.getcwd(), "datasets")}) def run(self): + """ + Runs the training process and copies the trained model to target_path when done. + """ self._parition_data() - return self._train() - - def _are_training_data_up_to_date(self): - if not os.path.exists(self.dataset_path): - return False - - if any( - not os.path.exists(os.path.join(self.dataset_path, "images", split_name)) - for split_name in self.split_paths + self._train() + + def calc_next_version(self): + """ + Calculates the next version based on the checksums of the files in the raw dataset path. + + Returns: + str: The MD5 hash of the concatenated checksums of all files in the raw dataset path. + Returns None if the raw dataset path is empty or does not exist. 
+ """ + if not os.path.exists(self.config.raw_dataset_path) or not os.listdir( + self.config.raw_dataset_path ): - return False + return None - raw_files = set(os.listdir(self.raw_dataset_path)) - files = set( - os.listdir(os.path.join(self.dataset_path, "images", "train")) - + os.listdir(os.path.join(self.dataset_path, "images", "test")) - + os.listdir(os.path.join(self.dataset_path, "images", "val")) - ) - return raw_files == files + checksums = [] + for file in os.listdir(self.config.raw_dataset_path): + with open(os.path.join(self.config.raw_dataset_path, file), "rb") as f: + checksums.append(hashlib.md5(f.read()).hexdigest()) - def _parition_data(self): - if self._are_training_data_up_to_date(): - return + return hashlib.md5("".join(checksums).encode("utf-8")).hexdigest() + def _parition_data(self): DataPartitioner(self.config).partition() ImageAugmenterStrategy(self.config).run() @@ -61,8 +76,7 @@ def _train(self): device=self.config.processing_device, ) + version = self.calc_next_version() model_path = os.path.join(results.save_dir, "weights", "best.pt") - target_model_path = os.path.join("cameratokeyboard", "model.pt") - shutil.copyfile(model_path, target_model_path) - - return results + os.makedirs(self._target_path, exist_ok=True) + shutil.copyfile(model_path, os.path.join(self._target_path, f"{version}.pt")) diff --git a/cameratokeyboard/types.py b/cameratokeyboard/types.py index bd61727..7d04dd9 100644 --- a/cameratokeyboard/types.py +++ b/cameratokeyboard/types.py @@ -136,3 +136,6 @@ def by_index(cls, index) -> Finger: Returns a finger by its index. 
""" return next((x for x in cls.values() if x.index == index), None) + + +S3ContentsItem = namedtuple("S3ContentsItem", ["key", "last_modified"]) diff --git a/cameratokeyboard/utils/__init__.py b/cameratokeyboard/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cameratokeyboard/utils/s3_response.py b/cameratokeyboard/utils/s3_response.py new file mode 100644 index 0000000..dea87cf --- /dev/null +++ b/cameratokeyboard/utils/s3_response.py @@ -0,0 +1,38 @@ +from datetime import datetime +from xml.dom.minidom import parseString + +from cameratokeyboard.types import S3ContentsItem + + +class S3Response: + """ + Parses the XML response of S3 api calls. + """ + + def __init__(self, response: str) -> None: + self._document = parseString(response) + + def get_objects(self): + """ + Returns the list of objects + """ + objects = [] + + for content in self._contents(): + try: + key = content.getElementsByTagName("Key")[0].firstChild.nodeValue + last_modified = content.getElementsByTagName("LastModified")[ + 0 + ].firstChild.nodeValue + last_modified = datetime.strptime( + last_modified, "%Y-%m-%dT%H:%M:%S.%fZ" + ) + + objects.append(S3ContentsItem(key=key, last_modified=last_modified)) + except IndexError: + continue + + return objects + + def _contents(self): + return self._document.getElementsByTagName("Contents") diff --git a/ci_train_and_upload.py b/ci_train_and_upload.py index c02d764..42a5cda 100644 --- a/ci_train_and_upload.py +++ b/ci_train_and_upload.py @@ -1,8 +1,6 @@ # pylint: disable=missing-function-docstring -import hashlib import os -import shutil import sys import tempfile @@ -12,41 +10,26 @@ from cameratokeyboard.logger import get_logger from cameratokeyboard.model.train import Trainer -REGION = os.environ["AWS_REGION"] -BUCKET_NAME = os.environ["AWS_BUCKET_NAME"] -RAW_DATASET_PATH = "raw_dataset" -REMOTE_MODELS_DIR = "models" - logger = get_logger() -s3_client = boto3.client("s3", region_name=REGION) - - -def get_next_version(): - 
logger.info("Calculating the checksum of the current dataset") +config = Config(processing_device="cpu") +s3_client = boto3.client("s3", region_name=config.remote_models_bucket_region) - if not os.path.exists(RAW_DATASET_PATH) or not os.listdir(RAW_DATASET_PATH): - logger.info("Raw dataset not found.") - return None +MODEL_TARGET_DIR = os.path.join(tempfile.tempdir, "c2k") - checksums = [] - for file in os.listdir(RAW_DATASET_PATH): - with open(os.path.join(RAW_DATASET_PATH, file), "rb") as f: - checksums.append(hashlib.md5(f.read()).hexdigest()) - - return hashlib.md5("".join(checksums).encode("utf-8")).hexdigest() +trainer = Trainer(config, target_path=MODEL_TARGET_DIR) def version_already_exists(version: str) -> bool: logger.info("Checking for changes") objects = s3_client.list_objects_v2( - Bucket=BUCKET_NAME, Prefix=f"models/{version}.pt" + Bucket=config.remote_models_bucket_name, Prefix=f"models/{version}.pt" ) return objects and objects.get("KeyCount", 0) > 0 def train(): - version = get_next_version() + version = trainer.calc_next_version() if not version: return @@ -59,33 +42,32 @@ def train(): logger.info("Training the model") - config = Config() - config.processing_device = "cpu" - trainer = Trainer(config) - - results = trainer.run() - - model_path = os.path.join(results.save_dir, "weights", "best.pt") - dest_path = os.path.join(tempfile.tempdir, "c2k") - os.makedirs(dest_path, exist_ok=True) - shutil.copyfile(model_path, os.path.join(dest_path, f"{version}.pt")) + trainer.run() - logger.info("Saved trained model to %s/%s.pt", dest_path, version) + logger.info("Saved trained model to %s/%s.pt", MODEL_TARGET_DIR, version) def upload_model(): - version = get_next_version() + version = trainer.calc_next_version() model_name = f"{version}.pt" - model_path = os.path.join(tempfile.tempdir, "c2k", model_name) + model_path = os.path.join(MODEL_TARGET_DIR, model_name) + + if not os.path.exists(model_path): + return + logger.info( - "Uploading %s to 
s3://%s/%s/%s", + "Uploading %s to s3://%s/%s%s", model_path, - BUCKET_NAME, - REMOTE_MODELS_DIR, + config.remote_models_bucket_name, + config.remote_models_prefix, model_name, ) - s3_client.upload_file(model_path, BUCKET_NAME, f"{REMOTE_MODELS_DIR}/{model_name}") + s3_client.upload_file( + model_path, + config.remote_models_bucket_name, + f"{config.remote_models_prefix}{model_name}", + ) def main(): diff --git a/requirements.txt b/requirements.txt index f4703d6..042fbf5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ colorlog ~= 4.6 imgaug ~= 0.4 numpy ~= 1.26 opencv-contrib-python ~= 4.9 +platformdirs ~= 4.2 pygame-ce ~= 2.4 pygame-gui ~= 0.6 pyyaml ~= 6.0 diff --git a/tests/app/test_detector.py b/tests/app/test_detector.py index 4e05892..7ff7f58 100644 --- a/tests/app/test_detector.py +++ b/tests/app/test_detector.py @@ -4,6 +4,13 @@ from cameratokeyboard.app.detector import Detector +from tests.fixtures import s3_objects_response + +MODELS_DIR = "/path/to/models" +BUCKET_NAME = "bucket_name" +REGION = "eu-west-2" +PREFIX = "models/" + class YoloMock: def __init__(self, *args, **kwargs): @@ -14,13 +21,24 @@ def __call__(self, *args, **kwargs): @pytest.fixture -def config(): - return Mock(iou=0.5, processing_device=0) +def base_config(): + return { + "iou": 0.5, + "models_dir": MODELS_DIR, + "remote_models_bucket_name": BUCKET_NAME, + "remote_models_bucket_region": REGION, + "remote_models_prefix": PREFIX, + } @pytest.fixture -def config_with_cpu(): - return Mock(iou=0.5, processing_device="cpu") +def config(base_config): + return Mock(processing_device=0, **base_config) + + +@pytest.fixture +def config_with_cpu(base_config): + return Mock(processing_device="cpu", **base_config) @pytest.fixture @@ -47,6 +65,13 @@ def detected_frame_class(): yield detected_frame_class +@pytest.fixture(autouse=True) +def requests_mock(): + with patch("cameratokeyboard.model.model_downloader.requests") as mock: + mock.get.return_value = 
MagicMock(content=s3_objects_response.encode("utf-8")) + yield mock + + def test_detect(yolo_mock, config, frame, detected_frame_class): detector = Detector(config) diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 0000000..3efacff --- /dev/null +++ b/tests/fixtures/__init__.py @@ -0,0 +1,8 @@ +import os + +s3_objects_response = None + +with open( + os.path.join(os.path.dirname(__file__), "s3_objects.xml"), "r", encoding="utf-8" +) as f: + s3_objects_response = f.read() diff --git a/tests/fixtures/s3_objects.xml b/tests/fixtures/s3_objects.xml new file mode 100644 index 0000000..875d2ab --- /dev/null +++ b/tests/fixtures/s3_objects.xml @@ -0,0 +1,69 @@ + + + c2k + + + 1000 + false + + 20240323-131619.txt + 2024-03-23T13:16:21.000Z + "d8e8fca2dc0f896fd7cb4cb0031ba249" + 5 + + 588d9a2aaae3b0a06bfdcf7f27bd0c83088df2006464b627c48b450032c6bbc8 + + STANDARD + + + models/ + 2024-03-23T13:59:13.000Z + "d41d8cd98f00b204e9800998ecf8427e" + 0 + + 588d9a2aaae3b0a06bfdcf7f27bd0c83088df2006464b627c48b450032c6bbc8 + + STANDARD + + + models/72cf8b59b60538ba46cdda79bd38afaa.pt + 2024-03-25T01:07:33.000Z + "16f9e117c6df5fb165e455e9b9a3f596" + 6249049 + + 588d9a2aaae3b0a06bfdcf7f27bd0c83088df2006464b627c48b450032c6bbc8 + + STANDARD + + + models/daab7a39b2f4cea3e435dc12e89cbe9d.pt + 2024-03-24T23:20:26.000Z + "05fd8c051e03327e9c70f890cf37361f" + 6246809 + + 588d9a2aaae3b0a06bfdcf7f27bd0c83088df2006464b627c48b450032c6bbc8 + + STANDARD + + + test.txt + 2024-03-23T11:38:01.000Z + "6bd8c645880b0d1c7f226e3607eec3d0" + 10 + + 588d9a2aaae3b0a06bfdcf7f27bd0c83088df2006464b627c48b450032c6bbc8 + + STANDARD + + + test2.txt + 2024-03-23T11:52:07.000Z + "3144e4b96f7fd15227332e23b8537d89" + 3 + + 588d9a2aaae3b0a06bfdcf7f27bd0c83088df2006464b627c48b450032c6bbc8 + + STANDARD + + + \ No newline at end of file diff --git a/tests/model/test_augmenter.py b/tests/model/test_augmenter.py new file mode 100644 index 0000000..f95cbba --- /dev/null +++ 
b/tests/model/test_augmenter.py @@ -0,0 +1,58 @@ +# pylint: disable=missing-function-docstring,redefined-outer-name + +from unittest.mock import patch, MagicMock, mock_open + +import pytest + +from cameratokeyboard.config import Config +from cameratokeyboard.model.augmenter import ImageAugmenterStrategy + +FILES_LIST = ["file1.jpg", "file2.jpg"] +STRATEGIES = ["Scale:0.5,1.2", "Rotation:-45,45"] + + +@pytest.fixture +def config(): + return Config + + +@pytest.fixture +def scale_augmenter_mock(): + with patch("cameratokeyboard.model.augmenter.augmenters.ScaleAugmenter") as mock: + mock.return_value.apply.return_value = ["image", "label"] + yield mock + + +@pytest.fixture +def rotation_augmenter_mock(): + with patch("cameratokeyboard.model.augmenter.augmenters.RotationAugmenter") as mock: + mock.return_value.apply.return_value = ["image", "label"] + yield mock + + +@pytest.fixture +def cv2_mock(): + with patch("cameratokeyboard.model.augmenter.cv2") as mock: + yield mock + + +@pytest.fixture +def mock_file_open(): + with patch("builtins.open", mock_open(read_data="data")) as mock: + yield mock + + +@patch("cameratokeyboard.model.augmenter.STRATEGIES", STRATEGIES) +@patch( + "cameratokeyboard.model.augmenter.os.listdir", MagicMock(return_value=FILES_LIST) +) +def test_run( + config, scale_augmenter_mock, rotation_augmenter_mock, mock_file_open, cv2_mock +): + strategy = ImageAugmenterStrategy(config) + strategy.run() + + assert cv2_mock.imread.called + assert scale_augmenter_mock.return_value.apply.called + assert rotation_augmenter_mock.return_value.apply.called + assert cv2_mock.imwrite.called diff --git a/tests/model/test_augmenters.py b/tests/model/test_augmenters.py new file mode 100644 index 0000000..f11fccc --- /dev/null +++ b/tests/model/test_augmenters.py @@ -0,0 +1,70 @@ +# pylint: disable=missing-function-docstring +from unittest.mock import patch, MagicMock + +import imgaug +import numpy as np + +from cameratokeyboard.model.augmenters import ( + 
ImageAugmenter, + ScaleAugmenter, + RotationAugmenter, + VerticalFlipAugmenter, + HorizontalFlipAugmenter, + BlurAugmenter, + ShearAugmenter, + PerspectiveAugmenter, +) + + +def test_scale_augmenter(): + assert issubclass(ScaleAugmenter, ImageAugmenter) + augmenter = ScaleAugmenter(0.05, 1.0) + assert isinstance(augmenter.augmenter, imgaug.augmenters.Affine) + + +def test_Rotation_augmenter(): + assert issubclass(RotationAugmenter, ImageAugmenter) + augmenter = ScaleAugmenter(0.05, 1.0) + assert isinstance(augmenter.augmenter, imgaug.augmenters.Affine) + + +def test_vertical_flip_augmenter(): + assert issubclass(VerticalFlipAugmenter, ImageAugmenter) + augmenter = VerticalFlipAugmenter() + assert isinstance(augmenter.augmenter, imgaug.augmenters.Flipud) + + +def test_horizontal_flip_augmenter(): + assert issubclass(HorizontalFlipAugmenter, ImageAugmenter) + augmenter = HorizontalFlipAugmenter() + assert isinstance(augmenter.augmenter, imgaug.augmenters.Fliplr) + + +@patch("cameratokeyboard.model.augmenters.iaa.GaussianBlur") +def test_blur_augmenter(guassian_blur_mock): + guassian_blur_mock.return_value.return_value = ( + "image", + MagicMock( + bounding_boxes=[ + MagicMock(label="0", center_x=10, center_y=10, width=10, height=10) + ], + shape=(20, 20), + ), + ) + assert issubclass(BlurAugmenter, ImageAugmenter) + augmenter = BlurAugmenter(0.5, 1) + augmenter.apply(np.zeros((100, 100, 1)), MagicMock()) + + assert guassian_blur_mock.return_value.called + + +def test_shear_augmenter(): + assert issubclass(ShearAugmenter, ImageAugmenter) + augmenter = ShearAugmenter(1.0, 2.0) + assert isinstance(augmenter.augmenter, imgaug.augmenters.ShearX) + + +def test_perspective_augmenter(): + assert issubclass(PerspectiveAugmenter, ImageAugmenter) + augmenter = PerspectiveAugmenter(1.0, 2.0) + assert isinstance(augmenter.augmenter, imgaug.augmenters.PerspectiveTransform) diff --git a/tests/model/test_model_downloader.py b/tests/model/test_model_downloader.py new file mode 100644 
index 0000000..c32064b --- /dev/null +++ b/tests/model/test_model_downloader.py @@ -0,0 +1,84 @@ +# pylint: disable=missing-function-docstring,redefined-outer-name + +from datetime import datetime, timedelta +from unittest.mock import call, patch, MagicMock + +import pytest + +from cameratokeyboard.model.model_downloader import ModelDownloader +from tests.fixtures import s3_objects_response + +MODELS_DIR = "/path/to/models" +BUCKET_NAME = "bucket_name" +REGION = "eu-west-2" +PREFIX = "models/" +LATEST_MODEL = "file1.pt" + + +@pytest.fixture +def config_mock(): + return MagicMock( + models_dir=MODELS_DIR, + remote_models_bucket_name=BUCKET_NAME, + remote_models_bucket_region=REGION, + remote_models_prefix=PREFIX, + ) + + +@pytest.fixture +def os_mock(): + with patch("cameratokeyboard.model.model_downloader.os") as mock: + mock.listdir.return_value = [] + yield mock + + +@pytest.fixture +def os_model_exists_mock(): + with patch("cameratokeyboard.model.model_downloader.os") as mock: + mock.listdir.return_value = [LATEST_MODEL] + yield mock + + +@pytest.fixture +def requests_mock(): + def get_mock(url, *args, **kwargs): + if url == "https://bucket_name.s3.eu-west-2.amazonaws.com": + return MagicMock(content=s3_objects_response.encode("utf-8")) + else: + return MagicMock(content="thefilecontents".encode("utf-8")) + + with patch("cameratokeyboard.model.model_downloader.requests") as mock: + mock.get.side_effect = get_mock + yield mock + + +def test_run(config_mock, os_mock, requests_mock): + model_downloader = ModelDownloader(config_mock) + model_downloader.run() + + assert os_mock.makedirs.call_args_list == [call("/path/to/models", exist_ok=True)] + print(requests_mock.get.call_args_list) + assert requests_mock.get.call_args_list == [ + call("https://bucket_name.s3.eu-west-2.amazonaws.com", timeout=10.0), + call( + "https://bucket_name.s3.eu-west-2.amazonaws.com/models/72cf8b59b60538ba46cdda79bd38afaa.pt", + timeout=10.0, + stream=True, + ), + ] + + +# def 
test_run_no_remote_models(config_mock, os_mock, boto3_client_no_models_mock): +# model_downloader = ModelDownloader(config_mock) +# model_downloader.run() +# +# assert not boto3_client_no_models_mock.return_value.download_file.called +# +# +# def test_run_model_already_downloaded( +# config_mock, os_model_exists_mock, boto3_client_mock +# ): +# model_downloader = ModelDownloader(config_mock) +# model_downloader.run() +# +# assert not boto3_client_mock.return_value.download_file.called diff --git a/tests/model/test_partitioner.py b/tests/model/test_partitioner.py new file mode 100644 index 0000000..309ff7d --- /dev/null +++ b/tests/model/test_partitioner.py @@ -0,0 +1,126 @@ +# pylint: disable=missing-function-docstring,redefined-outer-name + +from unittest.mock import patch, MagicMock, call + +import pytest + +from cameratokeyboard.config import Config +from cameratokeyboard.model.partitioner import DataPartitioner + +IMAGES_LIST = ["image1.jpg", "image2.jpg", "image3.jpg", "image4.jpg"] +LABELS_LIST = ["image1.txt", "image2.txt", "image3.txt", "image4.txt"] + + +@pytest.fixture +def config(): + config = Config() + config.split_ratios = [0.5, 0.25, 0.25] + + return config + + +def makedirs_call_args(config): + return [ + call(path) + for path in [ + f"{config.dataset_path}/images/train", + f"{config.dataset_path}/labels/train", + f"{config.dataset_path}/images/test", + f"{config.dataset_path}/labels/test", + f"{config.dataset_path}/images/val", + f"{config.dataset_path}/labels/val", + ] + ] + + +def copyfile_call_args(config): + return [ + call( + f"{config.raw_dataset_path}/image1.jpg", + f"{config.dataset_path}/images/train/00000.jpg", + ), + call( + f"{config.raw_dataset_path}/image1.txt", + f"{config.dataset_path}/labels/train/00000.txt", + ), + call( + f"{config.raw_dataset_path}/image2.jpg", + f"{config.dataset_path}/images/train/00001.jpg", + ), + call( + f"{config.raw_dataset_path}/image2.txt", + f"{config.dataset_path}/labels/train/00001.txt", + ), + call( 
+ f"{config.raw_dataset_path}/image3.jpg", + f"{config.dataset_path}/images/test/00002.jpg", + ), + call( + f"{config.raw_dataset_path}/image3.txt", + f"{config.dataset_path}/labels/test/00002.txt", + ), + call( + f"{config.raw_dataset_path}/image4.jpg", + f"{config.dataset_path}/images/val/00003.jpg", + ), + call( + f"{config.raw_dataset_path}/image4.txt", + f"{config.dataset_path}/labels/val/00003.txt", + ), + ] + + +@pytest.fixture +def os_mock_all_paths_exist(): + def join(*args): + return "/".join(args) + + with patch("cameratokeyboard.model.partitioner.os") as mock: + mock.listdir.return_value = IMAGES_LIST + mock.path.exists.return_value = True + mock.path.join = join + yield mock + + +@pytest.fixture +def os_mock_some_paths_dont_exist(os_mock_all_paths_exist): + def exists(file): + if file.endswith("txt"): + return False + return True + + os_mock_all_paths_exist.path.exists = exists + + yield os_mock_all_paths_exist + + +@pytest.fixture +def shutil_mock(): + with patch("cameratokeyboard.model.partitioner.shutil") as mock: + yield mock + + +@pytest.fixture +def shuffle_mock(): + def shuffle(iterable): + return iterable + + with patch("cameratokeyboard.model.partitioner.shuffle", shuffle) as mock: + yield mock + + +def test_partition(config, os_mock_all_paths_exist, shutil_mock, shuffle_mock): + DataPartitioner(config).partition() + + assert shutil_mock.rmtree.call_args_list == [call(config.dataset_path)] + assert os_mock_all_paths_exist.makedirs.call_args_list == makedirs_call_args(config) + + # shuffle is patched to leave the order untouched, so the copy order is deterministic. + assert shutil_mock.copyfile.call_args_list == copyfile_call_args(config) + + +def test_partition_labels_missing(config, os_mock_some_paths_dont_exist, shutil_mock): + partitioner = DataPartitioner(config) + + with pytest.raises(ValueError): + partitioner.partition() diff --git a/tests/model/test_train.py b/tests/model/test_train.py new file mode 100644 index 0000000..09fa4b0 --- /dev/null +++ b/tests/model/test_train.py @@
-0,0 +1,99 @@ +# pylint: disable=missing-function-docstring,redefined-outer-name,unused-argument +from unittest.mock import patch, call, mock_open + +import platformdirs +import pytest + +from cameratokeyboard.config import Config +from cameratokeyboard.model.train import Trainer + + +@pytest.fixture +def config(): + return Config() + + +@pytest.fixture +def path_exists_mock(): + with patch("cameratokeyboard.model.train.os.path.exists") as mock: + mock.return_value = True + yield mock + + +@pytest.fixture +def listdir_mock(): + with patch("cameratokeyboard.model.train.os.listdir") as mock: + mock.return_value = ["file1", "file2"] + yield mock + + +@pytest.fixture +def makedirs_mock(): + with patch("cameratokeyboard.model.train.os.makedirs") as mock: + yield mock + + +@pytest.fixture +def file_mock(): + with patch("builtins.open", mock_open(read_data="a".encode("utf-8"))) as mock: + yield mock + + +@pytest.fixture +def partitioner_mock(): + with patch("cameratokeyboard.model.train.DataPartitioner") as mock: + yield mock + + +@pytest.fixture +def augmenter_mock(): + with patch("cameratokeyboard.model.train.ImageAugmenterStrategy") as mock: + yield mock + + +@pytest.fixture +def yolo_mock(): + with patch("cameratokeyboard.model.train.YOLO") as mock: + mock.return_value.train.return_value.save_dir = "/somewhere" + yield mock + + +@pytest.fixture +def copyfile_mock(): + with patch("cameratokeyboard.model.train.shutil.copyfile") as mock: + yield mock + + +def test_calc_next_version(config, path_exists_mock, listdir_mock, file_mock): + trainer = Trainer(config) + assert trainer.calc_next_version() == "f3a0377ce26903122eb91b2851f97c96" + + +def test_train_default_path( + config, + path_exists_mock, + listdir_mock, + file_mock, + makedirs_mock, + partitioner_mock, + augmenter_mock, + yolo_mock, + copyfile_mock, +): + trainer = Trainer(config) + trainer.run() + target_dir = f"{platformdirs.user_data_dir()}/c2k/models" + + assert partitioner_mock.return_value.partition.called 
+ assert augmenter_mock.return_value.run.called + + assert makedirs_mock.called and makedirs_mock.call_args_list == [ + call(target_dir, exist_ok=True) + ] + + assert copyfile_mock.called and copyfile_mock.call_args_list == [ + call( + "/somewhere/weights/best.pt", + f"{target_dir}/f3a0377ce26903122eb91b2851f97c96.pt", + ) + ] diff --git a/tests/test_ci_train_and_upload.py b/tests/test_ci_train_and_upload.py new file mode 100644 index 0000000..9e4197c --- /dev/null +++ b/tests/test_ci_train_and_upload.py @@ -0,0 +1,82 @@ +# pylint: disable=missing-function-docstring,redefined-outer-name,unused-argument +from unittest.mock import patch, call +import tempfile + +import pytest + +from ci_train_and_upload import version_already_exists, train, upload_model + +MODEL_VERSION = "model_version" + + +@pytest.fixture +def s3_client_mock_model_exists(): + with patch("ci_train_and_upload.s3_client") as s3_mock: + s3_mock.list_objects_v2.return_value = {"KeyCount": 1} + yield s3_mock + + +@pytest.fixture +def s3_client_mock_model_does_not_exists(): + with patch("ci_train_and_upload.s3_client") as s3_mock: + s3_mock.list_objects_v2.return_value = {"KeyCount": 0} + yield s3_mock + + +@pytest.fixture +def trainer_mock(): + with patch("ci_train_and_upload.trainer") as mock: + mock.calc_next_version.return_value = MODEL_VERSION + yield mock + + +@pytest.fixture +def os_mock_model_exists(): + with patch("ci_train_and_upload.os.path.exists") as os_mock: + os_mock.return_value = True + yield os_mock + + +@pytest.fixture +def os_mock_model_does_not_exists(): + with patch("ci_train_and_upload.os.path.exists") as os_mock: + os_mock.return_value = False + yield os_mock + + +def test_version_already_exists_true(s3_client_mock_model_exists): + assert version_already_exists(MODEL_VERSION) + assert s3_client_mock_model_exists.list_objects_v2.called + + +def test_version_already_exists_false(s3_client_mock_model_does_not_exists): + assert not version_already_exists(MODEL_VERSION) + + +def 
test_train(s3_client_mock_model_does_not_exists, trainer_mock): + train() + assert trainer_mock.run.called + + +def test_train_version_already_exists(s3_client_mock_model_exists, trainer_mock): + train() + assert not trainer_mock.run.called + + +def test_upload( + s3_client_mock_model_does_not_exists, trainer_mock, os_mock_model_exists +): + upload_model() + assert s3_client_mock_model_does_not_exists.upload_file.called + assert s3_client_mock_model_does_not_exists.upload_file.call_args_list == [ + call( + f"{tempfile.tempdir}/c2k/model_version.pt", "c2k", "models/model_version.pt" + ) + ] + + +def test_upload_model_does_not_exist( + s3_client_mock_model_does_not_exists, trainer_mock, os_mock_model_does_not_exists +): + upload_model() + assert not s3_client_mock_model_does_not_exists.upload_file.called diff --git a/tests/test_config.py b/tests/test_config.py index 4f9e1ce..4ef4db5 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,11 +1,19 @@ # pylint: disable=missing-function-docstring - import os +from unittest.mock import patch + +import platformdirs from cameratokeyboard.config import Config -def test_config_defaults(): +@patch("cameratokeyboard.config.os.path.getctime") +@patch("cameratokeyboard.config.os.listdir") +def test_config_defaults(listdir_mock, getctime_mock): + model_name = "latest_model.pt" + listdir_mock.return_value = [model_name] + getctime_mock.return_value = 0 + config = Config() assert config.training_epochs == 20 assert config.training_image_size == (640, 640) @@ -16,7 +24,6 @@ def test_config_defaults(): assert config.split_ratios == (0.7, 0.15, 0.15) assert config.image_extension == "jpg" assert config.iou == 0.5 - assert config.model_path == os.path.join("cameratokeyboard", "model.pt") assert config.resolution == (1280, 720) assert config.app_fps == 30 assert config.video_input_device == 0 @@ -26,6 +33,10 @@ def test_config_defaults(): assert config.thumbs_min_confidence == 0.3 assert config.key_down_sensitivity == 0.75 
assert config.repeating_keys_delay == 0.5 + assert config.models_dir == os.path.join( + platformdirs.user_data_dir(), "c2k", "models" + ) + assert config.model_path == os.path.join(config.models_dir, model_name) def test_config_builder(): diff --git a/tests/utils/test_s3_response.py b/tests/utils/test_s3_response.py new file mode 100644 index 0000000..3b6d768 --- /dev/null +++ b/tests/utils/test_s3_response.py @@ -0,0 +1,29 @@ +# pylint: disable=missing-function-docstring +from datetime import datetime +from cameratokeyboard.utils.s3_response import S3Response +from cameratokeyboard.types import S3ContentsItem + +from tests.fixtures import s3_objects_response + + +def test_get_objects(): + parsed_response = S3Response(s3_objects_response) + + # The parsed listing is compared in full, including the bare "models/" prefix entry. + assert parsed_response.get_objects() == [ + S3ContentsItem( + key="20240323-131619.txt", + last_modified=datetime(2024, 3, 23, 13, 16, 21), + ), + S3ContentsItem(key="models/", last_modified=datetime(2024, 3, 23, 13, 59, 13)), + S3ContentsItem( + key="models/72cf8b59b60538ba46cdda79bd38afaa.pt", + last_modified=datetime(2024, 3, 25, 1, 7, 33), + ), + S3ContentsItem( + key="models/daab7a39b2f4cea3e435dc12e89cbe9d.pt", + last_modified=datetime(2024, 3, 24, 23, 20, 26), + ), + S3ContentsItem(key="test.txt", last_modified=datetime(2024, 3, 23, 11, 38, 1)), + S3ContentsItem(key="test2.txt", last_modified=datetime(2024, 3, 23, 11, 52, 7)), + ]