Skip to content

Commit

Permalink
Merge pull request #74 from okotaku/fix/test_lint
Browse files Browse the repository at this point in the history
[Enhance] Fix tests lint
  • Loading branch information
okotaku authored Oct 16, 2023
2 parents d817c76 + 96747fb commit b555fa9
Show file tree
Hide file tree
Showing 25 changed files with 376 additions and 459 deletions.
6 changes: 3 additions & 3 deletions tests/test_datasets/test_hf_controlnet_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def test_dataset_from_local(self):

data = dataset[0]
assert data["text"] == "a dog"
assert isinstance(data["img"], Image.Image)
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 400
assert isinstance(data["img"], Image.Image)
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 400
assert isinstance(data["condition_img"], Image.Image)
self.assertIsInstance(data["condition_img"], Image.Image)
assert data["condition_img"].width == 400
14 changes: 7 additions & 7 deletions tests/test_datasets/test_hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_dataset_from_local(self):

data = dataset[0]
assert data["text"] == "a dog"
assert isinstance(data["img"], Image.Image)
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 400

dataset = HFDataset(
Expand All @@ -25,7 +25,7 @@ def test_dataset_from_local(self):

data = dataset[0]
assert data["text"] == "a cat"
assert isinstance(data["img"], Image.Image)
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 400


Expand All @@ -41,9 +41,9 @@ def test_dataset_from_local(self):

data = dataset[0]
assert "text" not in data
assert isinstance(data["prompt_embeds"], list)
assert isinstance(data["pooled_prompt_embeds"], list)
assert np.array(data["prompt_embeds"]).shape == (77, 64)
assert np.array(data["pooled_prompt_embeds"]).shape == (32, )
assert isinstance(data["img"], Image.Image)
self.assertEqual(type(data["prompt_embeds"]), list)
self.assertEqual(type(data["pooled_prompt_embeds"]), list)
self.assertEqual(np.array(data["prompt_embeds"]).shape, (77, 64))
self.assertEqual(np.array(data["pooled_prompt_embeds"]).shape, (32, ))
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 400
29 changes: 13 additions & 16 deletions tests/test_datasets/test_hf_dreambooth_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,35 @@ def test_dataset(self):

data = dataset[0]
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 1815

def test_dataset_with_class_image(self):
dataset = HFDreamBoothDataset(
dataset="diffusers/dog-example",
instance_prompt="a photo of sks dog",
class_prompt="a photo of dog",
class_image_config={
"model": "diffusers/tiny-stable-diffusion-torch",
"data_dir": "temp_dir/class_image",
"num_images": 1,
"device": "cpu",
"recreate_class_images": True,
},
class_image_config=dict(
model="diffusers/tiny-stable-diffusion-torch",
data_dir="temp_dir/class_image",
num_images=1,
device="cpu",
recreate_class_images=True,
),
pipeline=[
{
"type": "PackInputs",
"skip_to_tensor_key": ["img", "text"],
},
dict(type="PackInputs", skip_to_tensor_key=["img", "text"]),
])
assert len(dataset) == 5
assert len(dataset.class_images) == 1

data = dataset[0]
assert data["inputs"]["text"] == "a photo of sks dog"
assert isinstance(data["inputs"]["img"], Image.Image)
self.assertIsInstance(data["inputs"]["img"], Image.Image)
assert data["inputs"]["img"].width == 1815

assert data["inputs"]["result_class_image"]["text"] == "a photo of dog"
assert isinstance(data["inputs"]["result_class_image"]["img"],
Image.Image)
self.assertIsInstance(data["inputs"]["result_class_image"]["img"],
Image.Image)
assert data["inputs"]["result_class_image"]["img"].width == 128
shutil.rmtree("temp_dir")

Expand All @@ -70,5 +67,5 @@ def test_dataset_from_local(self):

data = dataset[0]
assert data["text"] == "a photo of sks dog"
assert isinstance(data["img"], Image.Image)
self.assertIsInstance(data["img"], Image.Image)
assert data["img"].width == 400
18 changes: 9 additions & 9 deletions tests/test_datasets/test_hf_esd_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def test_dataset_from_local(self):

data = dataset[0]
assert "img" not in data
assert data["text"] == "Van Gogh"
assert type(data["prompt_embeds"]) == torch.Tensor
assert type(data["pooled_prompt_embeds"]) == torch.Tensor
assert type(data["null_prompt_embeds"]) == torch.Tensor
assert type(data["null_pooled_prompt_embeds"]) == torch.Tensor
assert data["prompt_embeds"].shape == (77, 64)
assert data["pooled_prompt_embeds"].shape == (32, )
assert data["null_prompt_embeds"].shape == (77, 64)
assert data["null_pooled_prompt_embeds"].shape == (32, )
self.assertEqual(data["text"], "Van Gogh")
self.assertEqual(type(data["prompt_embeds"]), torch.Tensor)
self.assertEqual(type(data["pooled_prompt_embeds"]), torch.Tensor)
self.assertEqual(type(data["null_prompt_embeds"]), torch.Tensor)
self.assertEqual(type(data["null_pooled_prompt_embeds"]), torch.Tensor)
self.assertEqual(data["prompt_embeds"].shape, (77, 64))
self.assertEqual(data["pooled_prompt_embeds"].shape, (32, ))
self.assertEqual(data["null_prompt_embeds"].shape, (77, 64))
self.assertEqual(data["null_pooled_prompt_embeds"].shape, (32, ))
32 changes: 18 additions & 14 deletions tests/test_datasets/test_samplers/test_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ def __len__(self):
return self.length

def __getitem__(self, idx):
results = {
"img": torch.zeros((3, self.shapes[idx][0], self.shapes[idx][1])),
"aspect_ratio": self.shapes[idx][0] / self.shapes[idx][1],
}
return {"inputs": results}
results = dict(
img=torch.zeros((3, self.shapes[idx][0], self.shapes[idx][1])),
aspect_ratio=self.shapes[idx][0] / self.shapes[idx][1])
return dict(inputs=results)


class TestAspectRatioBatchSampler(TestCase):
Expand All @@ -50,36 +49,41 @@ def test_divisible_batch(self):
batch_size = 5
batch_sampler = AspectRatioBatchSampler(
self.sampler, batch_size=batch_size, drop_last=True)
assert len(batch_sampler) == self.length / 2 // batch_size * 2
self.assertEqual(
len(batch_sampler), (self.length / 2 // batch_size) * 2)
for batch_idxs in batch_sampler:
assert len(batch_idxs) == batch_size
self.assertEqual(len(batch_idxs), batch_size)
batch = [
self.dataset[idx]["inputs"]["aspect_ratio"]
for idx in batch_idxs
]
for i in range(1, batch_size):
assert batch[0] == batch[i]
self.assertEqual(batch[0], batch[i])

def test_indivisible_batch(self):
batch_size = 7
batch_sampler = AspectRatioBatchSampler(
self.sampler, batch_size=batch_size, drop_last=True)
all_batch_idxs = list(batch_sampler)
assert len(batch_sampler) == self.length / 2 // batch_size * 2
assert len(all_batch_idxs) == self.length / 2 // batch_size * 2
self.assertEqual(
len(batch_sampler), (self.length / 2 // batch_size) * 2)
self.assertEqual(
len(all_batch_idxs), (self.length / 2 // batch_size) * 2)

batch_sampler = AspectRatioBatchSampler(
self.sampler, batch_size=batch_size, drop_last=False)
all_batch_idxs = list(batch_sampler)
assert len(batch_sampler) == self.length / 2 // batch_size * 2 + 2
assert len(all_batch_idxs) == self.length / 2 // batch_size * 2 + 2
self.assertEqual(
len(batch_sampler), (self.length / 2 // batch_size) * 2 + 2)
self.assertEqual(
len(all_batch_idxs), (self.length / 2 // batch_size) * 2 + 2)

# the last batch may not have the same aspect ratio
for batch_idxs in all_batch_idxs[:-2]:
assert len(batch_idxs) == batch_size
self.assertEqual(len(batch_idxs), batch_size)
batch = [
self.dataset[idx]["inputs"]["aspect_ratio"]
for idx in batch_idxs
]
for i in range(1, batch_size):
assert batch[0] == batch[i]
self.assertEqual(batch[0], batch[i])
14 changes: 7 additions & 7 deletions tests/test_datasets/test_transforms/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class TestPackInputs(unittest.TestCase):
def test_transform(self):
data = {"dummy": 1, "img": torch.zeros((3, 32, 32)), "text": "a"}

cfg = {"type": "PackInputs", "input_keys": ["img", "text"]}
cfg = dict(type="PackInputs", input_keys=["img", "text"])
transform = TRANSFORMS.build(cfg)
results = transform(copy.deepcopy(data))
assert "inputs" in results
self.assertIn("inputs", results)

assert "img" in results["inputs"]
assert isinstance(results["inputs"]["img"], torch.Tensor)
assert "text" in results["inputs"]
assert isinstance(results["inputs"]["text"], str)
assert "dummy" not in results["inputs"]
self.assertIn("img", results["inputs"])
self.assertIsInstance(results["inputs"]["img"], torch.Tensor)
self.assertIn("text", results["inputs"])
self.assertIsInstance(results["inputs"]["text"], str)
self.assertNotIn("dummy", results["inputs"])
Loading

0 comments on commit b555fa9

Please sign in to comment.