From 177b0a6104d16867cf37ea58488142e64b136e26 Mon Sep 17 00:00:00 2001 From: mertyg Date: Sun, 7 Jul 2024 10:23:00 -0700 Subject: [PATCH] check for variable values --- textgrad/utils/image_utils.py | 9 ++++++++- textgrad/variable.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/textgrad/utils/image_utils.py b/textgrad/utils/image_utils.py index 706f529..fb75c2d 100644 --- a/textgrad/utils/image_utils.py +++ b/textgrad/utils/image_utils.py @@ -4,10 +4,17 @@ from urllib.parse import urlparse from typing import Union import platformdirs -import base64 + +from urllib.parse import urlparse + +def is_valid_url(url): + result = urlparse(url) + return all([result.scheme, result.netloc]) + def download_and_cache_image(image_url: str) -> str: # Set up cache directory + assert is_valid_url(image_url), "Invalid URL" root = platformdirs.user_cache_dir("textgrad") image_cache_dir = os.path.join(root, "image_cache") os.makedirs(image_cache_dir, exist_ok=True) diff --git a/textgrad/variable.py b/textgrad/variable.py index 17f0b63..7f3af88 100644 --- a/textgrad/variable.py +++ b/textgrad/variable.py @@ -5,12 +5,8 @@ from collections import defaultdict from functools import partial from .config import SingletonBackwardEngine +from .utils.image_utils import is_valid_url from typing import Union -from urllib.parse import urlparse - -def is_valid_url(url): - result = urlparse(url) - return all([result.scheme, result.netloc]) class Variable: def __init__( @@ -45,15 +41,19 @@ def __init__( f"In this case, following predecessors require grad: {_predecessor_requires_grad}") assert type(value) in [str, bytes], "Value must be a string or image (bytes)." - + if value == "" and image_path == "": + raise ValueError("Please provide a value or an image path for the variable") + if value != "" and image_path != "": + raise ValueError("Please provide either a value or an image path for the variable, not both.") if image_path != "": if is_valid_url(image_path): self.value = httpx.get(image_path).content with open(image_path, 'rb') as file: self.value = file.read() - - self.value = value + else: + self.value = value + self.gradients: Set[Variable] = set() self.gradients_context: Dict[Variable, str] = defaultdict(lambda: None) self.grad_fn = None