Skip to content

Commit

Permalink
Merge pull request #1141 from JohnSnowLabs/feature/random-masking-on-…
Browse files Browse the repository at this point in the history
…images-tests

added new overlay classes for enhanced image robustness
  • Loading branch information
chakravarthik27 authored Nov 18, 2024
2 parents 9eaf24d + 7d55902 commit 12cc07e
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 8 deletions.
136 changes: 128 additions & 8 deletions langtest/transform/image/robustness.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Literal, Tuple, Union
from langtest.logger import logger
from langtest.transform.robustness import BaseRobustness
from langtest.transform.utils import get_default_font
from langtest.utils.custom_types.sample import Sample
from PIL import Image, ImageFilter, ImageDraw

Expand Down Expand Up @@ -403,6 +404,7 @@ def transform(
*args,
**kwargs,
) -> List[Sample]:
reset_mask = mask
for sample in sample_list:
sample.category = "robustness"
sample.test_type = "image_layered_mask"
Expand All @@ -422,6 +424,8 @@ def transform(
mask.putalpha(int(255 * opacity))
sample.perturbed_image.paste(mask, (0, 0), mask)

mask = reset_mask

return sample_list


Expand All @@ -432,23 +436,24 @@ class ImageTextOverlay(BaseRobustness):
@staticmethod
def transform(
sample_list: List[Sample],
text: str = "Hello, World!",
font_size: int = 20,
text: str = "LangTest",
font_size: int = 100,
font_color: Tuple[int, int, int] = (255, 255, 255),
position: Tuple[int, int] = (10, 10),
*args,
**kwargs,
) -> List[Sample]:
from PIL import ImageFont
# transperant text overlay on the image
font_color = font_color + (255,)

for sample in sample_list:
sample.category = "robustness"
sample.test_type = "image_text_overlay"
sample.perturbed_image = sample.original_image.copy()
draw = ImageDraw.Draw(sample.perturbed_image)
font = ImageFont.load_default()
font = get_default_font(font_size)

draw.text(
position,
(sample.original_image.width // 2, sample.original_image.height // 2),
text,
font=font,
fill=font_color,
Expand All @@ -464,7 +469,7 @@ class ImageWatermark(BaseRobustness):
@staticmethod
def transform(
sample_list: List[Sample],
watermark: Union[Image.Image, str],
watermark: Union[Image.Image, str] = None,
position: Tuple[int, int] = (10, 10),
opacity: float = 0.5,
*args,
Expand All @@ -481,7 +486,7 @@ def transform(
draw = ImageDraw.Draw(watermark)
draw.text(
position,
"Watermark",
"LangTest",
font=None,
fill=(255, 255, 255, int(255 * opacity)),
)
Expand All @@ -496,3 +501,118 @@ def transform(
sample.perturbed_image.paste(watermark, (0, 0), watermark)

return sample_list


class ImageRandomTextOverlay(BaseRobustness):
alias_name = "image_random_text_overlay"
supported_tasks = ["visualqa"]

@staticmethod
def transform(
sample_list: List[Sample],
opacity: float = 0.5,
font_size: int = 30,
random_texts: int = 10,
color: Tuple[int, int, int] = (0, 0, 0),
*args,
**kwargs,
) -> List[Sample]:
for sample in sample_list:
sample.category = "robustness"
sample.test_type = "image_random_text_overlay"
sample.perturbed_image = sample.original_image.copy()
overlay = Image.new("RGBA", sample.original_image.size, (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)

for _ in range(random_texts):
font = get_default_font(font_size)
x1 = random.randint(0, sample.original_image.width)
y1 = random.randint(0, sample.original_image.height)

draw.text(
(x1, y1),
"LangTest",
font=font,
fill=(*color, int(255 * opacity)),
)
sample.perturbed_image.paste(overlay, (0, 0), overlay)

return sample_list


class ImageRandomLineOverlay(BaseRobustness):
alias_name = "image_random_line_overlay"
supported_tasks = ["visualqa"]

@staticmethod
def transform(
sample_list: List[Sample],
color: Tuple[int, int, int] = (255, 0, 0),
opacity: float = 0.5,
random_lines: int = 10,
*args,
**kwargs,
) -> List[Sample]:
for sample in sample_list:
sample.category = "robustness"
sample.test_type = "image_random_line_overlay"
sample.perturbed_image = sample.original_image.copy()
overlay = Image.new("RGBA", sample.original_image.size)
overlay.putalpha(int(255 * opacity))
draw = ImageDraw.Draw(overlay)

for _ in range(random_lines):
# Get random points for the line
x1 = random.randint(0, sample.original_image.width)
y1 = random.randint(0, sample.original_image.height)
x2 = random.randint(0, sample.original_image.width)
y2 = random.randint(0, sample.original_image.height)

draw.line(
[(x1, y1), (x2, y2)],
fill=color + (int(255 * opacity),),
width=5,
)
sample.perturbed_image.paste(overlay, (0, 0), overlay)

return sample_list


class ImageRandomPolygonOverlay(BaseRobustness):
alias_name = "image_random_polygon_overlay"
supported_tasks = ["visualqa"]

@staticmethod
def transform(
sample_list: List[Sample],
color: Tuple[int, int, int] = (255, 0, 0),
opacity: float = 0.2,
random_polygons: int = 10,
*args,
**kwargs,
) -> List[Sample]:
for sample in sample_list:
sample.category = "robustness"
sample.test_type = "image_random_polygon_overlay"
sample.perturbed_image = sample.original_image.copy()
overlay = Image.new("RGBA", sample.original_image.size)
overlay.putalpha(int(255 * opacity))
draw = ImageDraw.Draw(overlay)

for _ in range(random_polygons):
# Get random points for the polygon vertices with random vertices
vertices = [
(
random.randint(0, sample.original_image.width),
random.randint(0, sample.original_image.height),
)
for _ in range(random.randint(3, 6))
]

draw.polygon(
vertices,
fill=color + (int(255 * opacity),),
)
sample.perturbed_image.paste(overlay, (0, 0), overlay)

return sample_list
23 changes: 23 additions & 0 deletions langtest/transform/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
bad_word_list,
)
from .custom_data import add_custom_data
from PIL import ImageFont
import os
import sys


class RepresentationOperation:
Expand Down Expand Up @@ -469,3 +472,23 @@ def compare_generations_overlap(phrase) -> int:
Answer (only A, B, C, or D):'
{assistant}
"""


def get_default_font(font_size=20):
"""
Returns a common font path available on all major operating systems.
Uses a fallback strategy for compatibility.
"""
if os.name == "nt": # Windows
return ImageFont.truetype("arial.ttf", font_size)
elif sys.platform == "darwin": # macOS
return ImageFont.truetype(
"/System/Library/Fonts/Supplemental/Helvetica.ttf", font_size
)
else: # Linux
try:
return ImageFont.truetype(
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size
)
except OSError:
return ImageFont.load_default()

0 comments on commit 12cc07e

Please sign in to comment.