diff --git a/finetune_instruct_pix2pix.py b/finetune_instruct_pix2pix.py index 6d353ae..9af9556 100644 --- a/finetune_instruct_pix2pix.py +++ b/finetune_instruct_pix2pix.py @@ -434,9 +434,16 @@ def convert_to_np(image, resolution): image = image.convert("RGB").resize((resolution, resolution)) return np.array(image).transpose(2, 0, 1) +def load_image(source): + if source.startswith('http'): + # Download image from URL + response = requests.get(source, stream=True) + response.raise_for_status() + image = PIL.Image.open(response.raw) + else: + # Open image from local path + image = PIL.Image.open(source) -def download_image(url): - image = PIL.Image.open(requests.get(url, stream=True).raw) image = PIL.ImageOps.exif_transpose(image) image = image.convert("RGB") return image @@ -462,7 +469,7 @@ def main(): gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, - logging_dir=logging_dir, + project_dir=logging_dir, project_config=accelerator_project_config, ) @@ -1075,7 +1082,7 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - original_image = download_image(args.val_image_url) + load_image(args.val_image_url) edited_images = [] with torch.autocast( str(accelerator.device),