siren model added #3

Open · wants to merge 1 commit into base: br/gaussian_embedding
94 changes: 47 additions & 47 deletions models.py
@@ -58,66 +58,66 @@ def forward(self, x):
return torch.cat([torch.sin(x_), torch.cos(x_)], dim=1) # (B, 2F)


# class Siren(nn.Module):
# def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
# first_omega_0=30, hidden_omega_0=30.):
# super().__init__()
class Siren(nn.Module):
def __init__(self, in_features=2, out_features=3, hidden_features=256, hidden_layers=4, outermost_linear=False,
             first_omega_0=30, hidden_omega_0=30.):
super().__init__()

# self.net = []
# self.net.append(SineLayer(in_features, hidden_features,
# is_first=True, omega_0=first_omega_0))
self.net = []
self.net.append(SineLayer(in_features, hidden_features,
is_first=True, omega_0=first_omega_0))

# for i in range(hidden_layers):
# self.net.append(SineLayer(hidden_features, hidden_features,
# is_first=False, omega_0=hidden_omega_0))
for i in range(hidden_layers):
self.net.append(SineLayer(hidden_features, hidden_features,
is_first=False, omega_0=hidden_omega_0))

# if outermost_linear:
# final_linear = nn.Linear(hidden_features, out_features)
if outermost_linear:
final_linear = nn.Linear(hidden_features, out_features)

# with torch.no_grad():
# final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
# np.sqrt(6 / hidden_features) / hidden_omega_0)
with torch.no_grad():
final_linear.weight.uniform_(-np.sqrt(6 / hidden_features) / hidden_omega_0,
np.sqrt(6 / hidden_features) / hidden_omega_0)

# self.net.append(final_linear)
# else:
# self.net.append(SineLayer(hidden_features, out_features,
# is_first=False, omega_0=hidden_omega_0))
self.net.append(final_linear)
else:
self.net.append(SineLayer(hidden_features, out_features,
is_first=False, omega_0=hidden_omega_0))

# self.net = nn.Sequential(*self.net)
self.net = nn.Sequential(*self.net)

# def forward(self, x):
# return self.net(x)
def forward(self, x):
return self.net(x)
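For reference, a minimal smoke test of the now-active Siren class (not part of the diff; it assumes models.py is importable and follows the signature above — note that outermost_linear defaults to False here, so the output passes through a final sine unless you opt into a linear head):

```python
# Hypothetical usage sketch: 2-D coordinates -> RGB with the defaults above.
import torch
from models import Siren

model = Siren(in_features=2, out_features=3,
              hidden_features=256, hidden_layers=4,
              outermost_linear=True)    # linear head, common for image regression
coords = torch.rand(1024, 2) * 2 - 1    # SIREN expects inputs roughly in [-1, 1]
rgb = model(coords)
print(rgb.shape)                        # torch.Size([1024, 3])
```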

# class SineLayer(nn.Module):
# # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
class SineLayer(nn.Module):
# See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.

# # If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the
# # nonlinearity. Different signals may require different omega_0 in the first layer - this is a
# # hyperparameter.
# If is_first=True, omega_0 is a frequency factor which simply multiplies the activations before the
# nonlinearity. Different signals may require different omega_0 in the first layer - this is a
# hyperparameter.

# # If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of
# # activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)
# If is_first=False, then the weights will be divided by omega_0 so as to keep the magnitude of
# activations constant, but boost gradients to the weight matrix (see supplement Sec. 1.5)

# def __init__(self, in_features, out_features, bias=True,
# is_first=False, omega_0=30):
# super().__init__()
# self.omega_0 = omega_0
# self.is_first = is_first
def __init__(self, in_features, out_features, bias=True,
is_first=False, omega_0=30):
super().__init__()
self.omega_0 = omega_0
self.is_first = is_first

# self.in_features = in_features
# self.linear = nn.Linear(in_features, out_features, bias=bias)
self.in_features = in_features
self.linear = nn.Linear(in_features, out_features, bias=bias)

# self.init_weights()
self.init_weights()

# def init_weights(self):
# with torch.no_grad():
# if self.is_first:
# self.linear.weight.uniform_(-1 / self.in_features,
# 1 / self.in_features)
# else:
# self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
# np.sqrt(6 / self.in_features) / self.omega_0)
def init_weights(self):
with torch.no_grad():
if self.is_first:
self.linear.weight.uniform_(-1 / self.in_features,
1 / self.in_features)
else:
self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
np.sqrt(6 / self.in_features) / self.omega_0)

# def forward(self, x):
# return torch.sin(self.omega_0 * self.linear(x))
def forward(self, x):
return torch.sin(self.omega_0 * self.linear(x))
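The omega_0 comments above can be sanity-checked numerically; a hedged sketch, assuming SineLayer is importable from models.py:

```python
# Hypothetical check of the init scheme described in the comments.
import numpy as np
import torch
from models import SineLayer

first = SineLayer(2, 256, is_first=True, omega_0=30)
hidden = SineLayer(256, 256, is_first=False, omega_0=30)

# First layer: weights drawn from U(-1/in_features, 1/in_features).
assert first.linear.weight.abs().max().item() <= 1 / 2

# Hidden layers: bounds divided by omega_0, so that omega_0 * W x
# keeps activation magnitudes constant while boosting weight gradients.
bound = np.sqrt(6 / 256) / 30
assert hidden.linear.weight.abs().max().item() <= bound
```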

3 changes: 3 additions & 0 deletions opt.py
@@ -12,6 +12,9 @@ def get_opts():
# gaussian embedding scale factor
parser.add_argument('--sc', type=float, default=10., help='Gaussian embedding scale factor')

# omega for siren
parser.add_argument('--omega', type=float, default=30., help='Omega for siren')

# batch size with 4
parser.add_argument('--batch_size', type=int, default=1024, help='Batch size')
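With the new flag, the SIREN frequency can be swept from the command line, e.g. `python train.py --arch siren --omega 30` (assuming opt.py also defines the `--arch` option that train.py reads as `hparams.arch`; that flag is outside this hunk).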

9 changes: 8 additions & 1 deletion train.py
@@ -12,7 +12,7 @@
from dataset import ImageDataset

# models
from models import MLP, PE
from models import MLP, PE, Siren

# optimizer
from torch.optim import Adam
@@ -45,11 +45,18 @@ def __init__(self, hparams):
self.pe = PE(P)
self.net = MLP(n_input=self.pe.out_dim)

elif hparams.arch == 'siren':
self.net = Siren(first_omega_0=hparams.omega,
                 hidden_omega_0=hparams.omega)

self.loss = MSELoss()

def forward(self, x):
if hparams.arch in ('identity', 'siren'):
    return self.net(x)
else:
return self.net(self.pe(x))
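A minimal sketch of what the new branch does end to end (hypothetical, not part of the diff): unlike the 'pe' path, the 'siren' path feeds raw coordinates to the network with no positional encoding.

```python
# Hypothetical: reproduce the 'siren' construction and forward path.
from argparse import Namespace
import torch
from models import Siren

hparams = Namespace(arch='siren', omega=30.)
net = Siren(first_omega_0=hparams.omega,
            hidden_omega_0=hparams.omega)
x = torch.rand(8, 2)   # raw coordinates, no PE applied
print(net(x).shape)    # torch.Size([8, 3]) with the default in/out features
```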
