-
Notifications
You must be signed in to change notification settings - Fork 0
/
mean.py
36 lines (29 loc) · 1007 Bytes
/
mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import os
import glob
import cv2
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import glob
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
def mean__std(data_loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Makes one pass over ``data_loader`` accumulating, per channel, the sum of
    pixel values and the sum of squared pixel values, then returns
    ``mean = E[x]`` and ``std = sqrt(E[x**2] - E[x]**2)``, i.e. the
    population standard deviation over every pixel of every image.

    Args:
        data_loader: Iterable of ``(data, label)`` pairs. ``data`` must be a
            4-D tensor, unpacked here as ``(batch, channels, height, width)``;
            ``label`` is ignored. Any number of channels is supported.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel. Their
        dtype is the input dtype promoted with torch's default float dtype
        (e.g. ``float32`` for ``uint8`` or ``float32`` input).

    Raises:
        ValueError: If ``data_loader`` yields no pixels at all.
    """
    sum_ = None  # per-channel running sum of x (float64)
    sum_of_square = None  # per-channel running sum of x**2 (float64)
    cnt = 0  # number of pixels seen in each channel
    out_dtype = torch.get_default_dtype()
    for data, _ in data_loader:
        b, _, h, w = data.size()
        # Accumulate in float64: avoids integer overflow of data ** 2 for
        # integer images and limits cancellation in E[x^2] - E[x]^2.
        data64 = data.to(torch.float64)
        batch_sum = torch.sum(data64, dim=[0, 2, 3])
        batch_sum_sq = torch.sum(data64 ** 2, dim=[0, 2, 3])
        if sum_ is None:
            # Start from the first batch's sums instead of an uninitialized
            # buffer; this also fixes the channel count and device.
            sum_, sum_of_square = batch_sum, batch_sum_sq
            out_dtype = torch.promote_types(data.dtype, torch.get_default_dtype())
        else:
            sum_ = sum_ + batch_sum
            sum_of_square = sum_of_square + batch_sum_sq
        cnt += b * h * w
    if cnt == 0:
        raise ValueError("data_loader yielded no pixels")
    mean = sum_ / cnt
    # Rounding can push the variance slightly below zero; clamp so that
    # sqrt never returns NaN.
    var = torch.clamp(sum_of_square / cnt - mean ** 2, min=0)
    return mean.to(out_dtype), torch.sqrt(var).to(out_dtype)
if __name__ == "__main__":
    # Guard the script entry point: DataLoader with num_workers > 0 starts
    # worker processes, and under the "spawn" start method (Windows, macOS)
    # each worker re-imports the main module. Unguarded top-level code would
    # rebuild and iterate a new DataLoader inside every worker and fail.
    import sys

    # Dataset root for ImageFolder; defaults to 'b' as before and can be
    # overridden with the first command-line argument.
    root = sys.argv[1] if len(sys.argv) > 1 else 'b'
    # NOTE(review): default batch collation needs every image to have the
    # same height/width (no resize transform here) — confirm for the dataset.
    train_data = torchvision.datasets.ImageFolder(root, transform=transforms.Compose([transforms.ToTensor()]))
    data_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False, num_workers=4)
    mean, std = mean__std(data_loader)
    print(mean, std)