-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_utils.py
119 lines (104 loc) · 3.57 KB
/
dataset_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
'''
Author: GhMa
Date: 2022-04-07 20:48:43
LastEditors: error: error: git config user.name & please set dead value or install git && error: git config user.email & please set dead value or install git & please set dead value or install git
LastEditTime: 2023-03-31 14:58:30
'''
import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import os
import random
import numpy as np
from tqdm import tqdm
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import List
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset
from scipy.ndimage import gaussian_filter1d
class NaturalMovie(torch.utils.data.Dataset):
    """Dataset of natural-movie frames paired with recorded responses.

    One split is loaded from ``<root>/exp_{exp}_mov_{mov}_{subset}_set.npy``,
    a pickled dict with the keys ``'stimuli'``, ``'frames'``, ``'response'``
    and ``'fr'``.

    Args:
        root: Directory containing the ``.npy`` files; a leading ``~`` is
            expanded.
        exp: Which experiment (different retina for different exp).
        mov: Which movie.
        subset: Split name used in the file name, e.g. ``'train'`` or
            ``'test'``.
        stim_hist: Number of leading entries along axis 1 of ``response`` and
            ``fr`` that ``__getitem__`` drops (presumably the stimulus-history
            length). When it equals 29, ``fr`` is also Gaussian-smoothed
            (sigma=1) at load time.
        zscore: If True, normalise ``frames`` with their global mean and std
            (std offset by 1e-6) and print summary statistics.
    """

    def __init__(
        self,
        root,
        exp=1,  # which experiment (different retina for different exp)
        mov=1,  # which movie
        subset='train',
        stim_hist=15,
        zscore=False
    ):
        self.root = os.path.expanduser(root)
        self.exp = exp
        self.mov = mov
        self.subset = subset

        # Join with the expanded root so paths such as '~/data' resolve.
        path = os.path.join(
            self.root, 'exp_{}_mov_{}_{}_set.npy'.format(exp, mov, subset)
        )
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # load trusted data files.
        dataset = np.load(path, allow_pickle=True).item()

        # Per-sample [start, end) frame ids into ``self.frames``.
        self.stimuli = torch.LongTensor(dataset['stimuli'])
        self.frames = torch.Tensor(dataset['frames'])
        if zscore:
            print('==> apply zscore norm')
            self.frames = (self.frames - self.frames.mean()) / (self.frames.std() + 1e-6)
            print(self.frames.min(), self.frames.max(), self.frames.mean())
        # Indexed as [sample, axis sliced by stim_history, last axis] in
        # __getitem__; the meaning of each axis is not visible here.
        self.response = torch.Tensor(dataset['response'])
        self.fr = torch.Tensor(dataset['fr'])
        self.stim_history = stim_hist

        if stim_hist == 29:
            # Smooth the predictors with a Gaussian (sigma=1) along the last
            # axis. A single call over the whole array matches smoothing each
            # sample separately, since gaussian_filter1d filters every 1-D
            # line along ``axis`` independently.
            # NOTE(review): axis 1 is the one sliced by stim_history in
            # __getitem__, so the last axis may be the cell axis rather than
            # time -- confirm the intended smoothing axis. Why smoothing is
            # tied to stim_hist == 29 is not visible here.
            self.fr = torch.Tensor(
                gaussian_filter1d(self.fr.numpy(), sigma=1, axis=-1)
            )

    def __getitem__(self, index):
        """Return ``(stimuli, response, fr)`` for sample ``index``.

        ``stimuli`` holds the frames whose ids lie in
        ``[self.stimuli[index][0], self.stimuli[index][1])``. ``response`` and
        ``fr`` are that sample's arrays with the first ``stim_history`` entries
        of axis 1 removed.
        """
        start = int(self.stimuli[index][0])
        end = int(self.stimuli[index][1])
        frame_ids = torch.arange(start, end)
        stimuli = self.frames[frame_ids]
        response = self.response[index, self.stim_history:, :]
        fr = self.fr[index, self.stim_history:, :]
        return stimuli, response, fr

    def __len__(self):
        """Number of samples (rows of ``self.stimuli``)."""
        return self.stimuli.size(0)
def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Build a DataLoader over ``dataset`` that keeps the last partial batch."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=num_workers,
    )


def prepare_natural_movie(args, stim_hist=15, shuffle_train=True, zscore=False,
                          root='/home/ghma/data/natural_mov_RGC_responses',
                          num_workers=4):
    """Create train/test NaturalMovie datasets and their DataLoaders.

    Args:
        args: Namespace providing ``exp``, ``mov`` and ``minibatch``.
        stim_hist: Forwarded to ``NaturalMovie``.
        shuffle_train: Whether the main training loader shuffles.
        zscore: Forwarded to ``NaturalMovie``.
        root: Directory containing the ``.npy`` split files (defaults to the
            path previously hard-coded here).
        num_workers: Worker processes used by every loader.

    Returns:
        ``(trainloader, testloader, stim_history, vis_trainloader,
        vis_testloader)``. The ``vis_*`` loaders use batch size 1 and no
        shuffling.
    """
    trainset = NaturalMovie(
        root=root, subset='train', exp=args.exp, mov=args.mov,
        stim_hist=stim_hist,
        zscore=zscore
    )
    testset = NaturalMovie(
        root=root, subset='test', exp=args.exp, mov=args.mov,
        stim_hist=stim_hist,
        zscore=zscore
    )

    trainloader = _make_loader(trainset, args.minibatch, shuffle_train, num_workers)
    # Never shuffle the test dataset.
    testloader = _make_loader(testset, args.minibatch, False, num_workers)
    # Single-sample, ordered loaders (presumably for visualisation).
    vis_trainloader = _make_loader(trainset, 1, False, num_workers)
    vis_testloader = _make_loader(testset, 1, False, num_workers)

    print('data loaded')
    return trainloader, testloader, trainset.stim_history, \
        vis_trainloader, vis_testloader