Skip to content

Commit

Permalink
check for variable values
Browse files Browse the repository at this point in the history
  • Loading branch information
mertyg committed Jul 7, 2024
1 parent 619b6ec commit 177b0a6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
9 changes: 8 additions & 1 deletion textgrad/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
from urllib.parse import urlparse
from typing import Union
import platformdirs
import base64

from urllib.parse import urlparse

def is_valid_url(url):
result = urlparse(url)
return all([result.scheme, result.netloc])


def download_and_cache_image(image_url: str) -> str:
# Set up cache directory
assert is_valid_url(image_url), "Invalid URL"
root = platformdirs.user_cache_dir("textgrad")
image_cache_dir = os.path.join(root, "image_cache")
os.makedirs(image_cache_dir, exist_ok=True)
Expand Down
16 changes: 8 additions & 8 deletions textgrad/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
from collections import defaultdict
from functools import partial
from .config import SingletonBackwardEngine
from .utils.image_utils import is_valid_url
from typing import Union
from urllib.parse import urlparse

def is_valid_url(url):
result = urlparse(url)
return all([result.scheme, result.netloc])

class Variable:
def __init__(
Expand Down Expand Up @@ -45,15 +41,19 @@ def __init__(
f"In this case, following predecessors require grad: {_predecessor_requires_grad}")

assert type(value) in [str, bytes], "Value must be a string or image (bytes)."

if value == "" and image_path == "":
raise ValueError("Please provide a value or an image path for the variable")
if value != "" and image_path != "":
raise ValueError("Please provide either a value or an image path for the variable, not both.")

if image_path != "":
if is_valid_url(image_path):
self.value = httpx.get(image_path).content
with open(image_path, 'rb') as file:
self.value = file.read()

self.value = value
else:
self.value = value

self.gradients: Set[Variable] = set()
self.gradients_context: Dict[Variable, str] = defaultdict(lambda: None)
self.grad_fn = None
Expand Down

0 comments on commit 177b0a6

Please sign in to comment.