From b32ad5eaea4a0e9599eee3f6dbbf638bce51a54a Mon Sep 17 00:00:00 2001 From: hh-space-invader Date: Wed, 4 Dec 2024 10:38:41 +0200 Subject: [PATCH] refactor: Refactor some functions in preprocess image --- fastembed/image/transform/functional.py | 2 +- fastembed/image/transform/operators.py | 15 +++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/fastembed/image/transform/functional.py b/fastembed/image/transform/functional.py index 47724515..8bd25427 100644 --- a/fastembed/image/transform/functional.py +++ b/fastembed/image/transform/functional.py @@ -140,7 +140,7 @@ def pad2square( image = image.crop((left, top, right, bottom)) return image - new_image = Image.new(mode="RGB", size=(size, size), color=fill_color) + new_image = Image.new(mode="RGB", size=(size, size), color=fill_color or 0) left = (size - width) // 2 top = (size - height) // 2 new_image.paste(image, (left, top)) diff --git a/fastembed/image/transform/operators.py b/fastembed/image/transform/operators.py index 0d08963d..82e2267b 100644 --- a/fastembed/image/transform/operators.py +++ b/fastembed/image/transform/operators.py @@ -38,9 +38,7 @@ def __init__(self, mean: Union[float, list[float]], std: Union[float, list[float self.std = std def __call__(self, images: list[np.ndarray]) -> list[np.ndarray]: - return [ - normalize(image, mean=np.array(self.mean), std=np.array(self.std)) for image in images - ] + return [normalize(image, mean=self.mean, std=self.std) for image in images] class Resize(Transform): @@ -183,11 +181,12 @@ def _get_resize(transforms: list[Transform], config: dict[str, Any]): ) ) elif mode == "JinaCLIPImageProcessor": - resample = ( - Compose._interpolation_resolver(config.get("interpolation")) - if isinstance(config.get("interpolation"), str) - else config.get("interpolation", Image.Resampling.BICUBIC) - ) + interpolation = config.get("interpolation") + if isinstance(interpolation, str): + resample = Compose._interpolation_resolver(interpolation) + else: + resample = interpolation or Image.Resampling.BICUBIC + if "size" in config: resize_mode = config.get("resize_mode", "shortest") if resize_mode == "shortest":