models.py
import torch
from torch import nn


class MLP(nn.Module):
    """
    MLP with n_layers linear layers (default 4): ReLU on the hidden
    layers and a sigmoid on the output so the rgb lies in [0, 1]
    """
    def __init__(self, n_layers=4, n_hidden=256, n_input=2, n_output=3):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_input = n_input
        self.n_output = n_output
        layers = [nn.Linear(self.n_input, self.n_hidden), nn.ReLU(True)]
        for i in range(self.n_layers - 1):
            if i != self.n_layers - 2:
                # hidden -> hidden
                layers.append(nn.Linear(self.n_hidden, self.n_hidden))
                layers.append(nn.ReLU(True))
            else:
                # hidden -> output
                layers.append(nn.Linear(self.n_hidden, self.n_output))
                layers.append(nn.Sigmoid())
        # nn.Sequential registers each layer as a submodule
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: (B, 2) pixel uv (normalized)
        """
        return self.net(x)  # (B, 3) rgb
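
# Illustrative shape check (not part of the original file): with the
# defaults, MLP maps normalized uv coordinates to rgb values:
#   mlp = MLP()                    # 2 -> 256 -> 256 -> 256 -> 3
#   rgb = mlp(torch.rand(8, 2))    # rgb.shape == (8, 3), values in (0, 1)
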
class PE(nn.Module):
    """
    Positional encoding: project the input with a fixed matrix P,
    then apply sin and cos (Fourier features)
    """
    def __init__(self, P):
        """
        P: (2, F) encoding matrix
        """
        super().__init__()
        # register P as a buffer: it moves with .to(device) and is saved
        # in the state_dict, but is not a trainable parameter
        self.register_buffer('P', P)

    @property
    def out_dim(self):
        return self.P.shape[1] * 2

    def forward(self, x):
        """
        x: (B, 2)
        """
        x_ = x @ self.P  # (B, F)
        return torch.cat([torch.sin(x_), torch.cos(x_)], dim=1)  # (B, 2F)
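
# A minimal end-to-end sketch (an assumption, not from this repo): build a
# random Gaussian encoding matrix P as in Fourier-feature positional
# encodings and size the MLP input to match PE.out_dim. The scale 10.0
# and F = 128 below are illustrative choices, not values from this file.
if __name__ == '__main__':
    F = 128
    P = 10.0 * torch.randn(2, F)     # (2, F) random Gaussian frequencies
    pe = PE(P)
    mlp = MLP(n_input=pe.out_dim)    # input dim 2F = 256
    uv = torch.rand(4, 2)            # (B, 2) normalized pixel coordinates
    rgb = mlp(pe(uv))                # (B, 3)
    print(rgb.shape)                 # torch.Size([4, 3])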