-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathus_datasets.py
executable file
·47 lines (38 loc) · 1.74 KB
/
us_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import glob
import random
import os
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch
from pathlib import Path
class USDataset(Dataset):
    """Unpaired ultrasound image dataset (CycleGAN-style A/B domains).

    Domain A images come from per-subject ``RB#####`` directories
    (every 30th frame); domain B images come from ``AE*`` directories
    (every 15th frame, train mode only).  Transformed images are
    rescaled from [0, 1] to [-1, 1] in ``__getitem__``.
    """

    def __init__(self, transforms_=None, unaligned=True, mode='train', medium='alcohol'):
        """
        Args:
            transforms_: list of torchvision transforms, composed in order.
                Should end with a tensor-producing transform (e.g.
                ``ToTensor``) since ``__getitem__`` does arithmetic on
                the result.
            unaligned: kept for interface compatibility; currently unused
                (A and B are always indexed independently).
            mode: 'train' loads both domains; any other value loads only
                domain A.
            medium: 'alcohol' selects subject ids 0-445, anything else
                selects ids 446-999.
        """
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned
        self.mode = mode
        # BUGFIX: `medium` was previously ignored — id_range was
        # hard-coded to range(446, 1000), so medium='alcohol' (the
        # default) still loaded the non-alcohol subjects.  Restore the
        # intended medium-dependent selection.
        id_range = range(0, 446) if medium == 'alcohol' else range(446, 1000)
        self.files_A = []
        for i in id_range:
            # Subsample: keep every 30th frame of each subject's sequence.
            self.files_A += glob.glob('/mnt/airsfs6/aiiv_processed_data/US_ID/RB%05d/*.png' % i)[::30]
        if self.mode == 'train':
            # Domain B: every 15th frame across all AE* directories.
            self.files_B = glob.glob('/mnt/airsfs6/aiiv_processed_data/US_ID/AE*/*.png')[::15]

    def __getitem__(self, index):
        """Return ``{'A': tensor, 'B': tensor}`` in train mode, else
        ``{'A': tensor, 'fn': path}``.

        Domain A images are cropped to 368x368 then resized to 512x512
        to match domain B's native 512x512 crop; both are rescaled to
        [-1, 1].  Indices wrap modulo each domain's length.
        """
        fn_A = self.files_A[index % len(self.files_A)]
        img_A = Image.open(fn_A).crop((0, 0, 368, 368)).resize((512, 512))
        item_A = self.transform(img_A) * 2 - 1
        if self.mode == 'train':
            img_B = Image.open(self.files_B[index % len(self.files_B)]).crop((0, 0, 512, 512))
            item_B = self.transform(img_B) * 2 - 1
            return {'A': item_A, 'B': item_B}
        return {'A': item_A, 'fn': fn_A}

    def __len__(self):
        """Length of the longer domain in train mode, else len(files_A)."""
        if self.mode == 'train':
            return max(len(self.files_A), len(self.files_B))
        return len(self.files_A)