Eliminate .data access for parameters as much as possible (pytorch#767)
vishwakftw authored May 21, 2020
1 parent 31643b2 commit 1877b87
Showing 6 changed files with 16 additions and 16 deletions.
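
The change follows one pattern across all six files: parameter tensors are initialized through the in-place torch.nn.init functions, which accept the Parameter directly and run under torch.no_grad(), instead of reaching into the .data attribute. A minimal before/after sketch of that pattern (illustrative only, not taken from the diff):

    import torch.nn as nn

    linear = nn.Linear(4, 2)

    # Old pattern: mutate the raw tensor hiding behind the Parameter.
    linear.weight.data.uniform_(-0.1, 0.1)
    linear.bias.data.zero_()

    # New pattern: nn.init takes the Parameter itself and performs the same
    # in-place operation inside torch.no_grad(), keeping autograd out of it.
    nn.init.uniform_(linear.weight, -0.1, 0.1)
    nn.init.zeros_(linear.bias)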
6 changes: 3 additions & 3 deletions dcgan/main.py
@@ -113,10 +113,10 @@
 def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
+        torch.nn.init.normal_(m.weight, 0.0, 0.02)
     elif classname.find('BatchNorm') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+        torch.nn.init.normal_(m.weight, 1.0, 0.02)
+        torch.nn.init.zeros_(m.bias)


 class Generator(nn.Module):
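
For context, dcgan/main.py applies this initializer recursively with Module.apply; a rough usage sketch (the Sequential below is a stand-in, not the example's actual Generator):

    import torch
    import torch.nn as nn

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal_(m.weight, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
            torch.nn.init.zeros_(m.bias)

    # apply() walks every submodule and calls weights_init on each one.
    netG = nn.Sequential(
        nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
    )
    netG.apply(weights_init)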
4 changes: 2 additions & 2 deletions distributed/rpc/pipeline/main.py
@@ -134,8 +134,8 @@ def __init__(self, device, *args, **kwargs):
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)

     def forward(self, x_rref):
         x = x_rref.to_here().to(self.device)
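
nn.init.ones_ and nn.init.zeros_ are functionally equivalent to the constant_ calls they replace here, just written as dedicated initializers; a quick equivalence check (illustrative):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(64)
    nn.init.ones_(bn.weight)   # same result as nn.init.constant_(bn.weight, 1)
    nn.init.zeros_(bn.bias)    # same result as nn.init.constant_(bn.bias, 0)
    assert torch.equal(bn.weight, torch.ones(64))
    assert torch.equal(bn.bias, torch.zeros(64))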
6 changes: 3 additions & 3 deletions distributed/rpc/rnn/rnn.py
@@ -42,7 +42,7 @@ def __init__(self, ntoken, ninp, dropout):
         super(EmbeddingTable, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.encoder = nn.Embedding(ntoken, ninp).cuda()
-        self.encoder.weight.data.uniform_(-0.1, 0.1)
+        nn.init.uniform_(self.encoder.weight, -0.1, 0.1)

     def forward(self, input):
         return self.drop(self.encoder(input.cuda())).cpu()

@@ -56,8 +56,8 @@ def __init__(self, ntoken, nhid, dropout):
         super(Decoder, self).__init__()
         self.drop = nn.Dropout(dropout)
         self.decoder = nn.Linear(nhid, ntoken)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-0.1, 0.1)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -0.1, 0.1)

     def forward(self, output):
         return self.decoder(self.drop(output))
2 changes: 1 addition & 1 deletion regression/main.py
@@ -57,7 +57,7 @@ def get_batch(batch_size=32):

     # Apply gradients
     for param in fc.parameters():
-        param.data.add_(-0.1 * param.grad.data)
+        param.add_(-0.1 * param.grad)

     # Stop criterion
     if loss < 1e-3:
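
Without .data, add_ now operates in place on a leaf tensor that requires grad, which autograd only allows inside torch.no_grad(); the surrounding context is not visible in this hunk, so the sketch below assumes the update runs under such a block (the data and loss are made up, only the fc name mirrors the file):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    fc = nn.Linear(4, 1)
    x, y = torch.randn(32, 4), torch.randn(32, 1)

    fc.zero_grad()
    loss = F.mse_loss(fc(x), y)
    loss.backward()

    # Apply gradients: an in-place SGD step on the parameters themselves.
    # Without .data the step must not be recorded by autograd, hence no_grad().
    with torch.no_grad():
        for param in fc.parameters():
            param.add_(-0.1 * param.grad)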
2 changes: 1 addition & 1 deletion word_language_model/main.py
@@ -178,7 +178,7 @@ def train():
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         for p in model.parameters():
-            p.data.add_(-lr, p.grad.data)
+            p.add_(-lr, p.grad)

         total_loss += loss.item()

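
The two-argument Tensor.add_(scalar, tensor) overload kept here has since been deprecated in favor of the alpha= keyword; an equivalent sketch of the step, again assuming it runs under torch.no_grad() (the stand-in model and lr value are illustrative):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)                    # stand-in for the language model
    model(torch.randn(2, 8)).sum().backward()  # populate .grad on each parameter

    lr = 20.0                                  # illustrative learning rate
    with torch.no_grad():
        for p in model.parameters():
            # Same update as p.add_(-lr, p.grad), spelled with the
            # non-deprecated alpha= keyword.
            p.add_(p.grad, alpha=-lr)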
12 changes: 6 additions & 6 deletions word_language_model/model.py
@@ -41,9 +41,9 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weigh

     def init_weights(self):
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

     def forward(self, input, hidden):
         emb = self.drop(self.encoder(input))
@@ -132,9 +132,9 @@ def _generate_square_subsequent_mask(self, sz):

     def init_weights(self):
         initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-        self.decoder.bias.data.zero_()
-        self.decoder.weight.data.uniform_(-initrange, initrange)
+        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
+        nn.init.zeros_(self.decoder.bias)
+        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

     def forward(self, src, has_mask=True):
         if has_mask:
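
One point worth noting about these init_weights rewrites: nn.init mutates the existing Parameter in place, so object identity and the requires_grad flag are preserved, just as with the old .data idiom. A small check (illustrative, not from the example):

    import torch.nn as nn

    decoder = nn.Linear(200, 10000)
    weight_before = decoder.weight

    nn.init.uniform_(decoder.weight, -0.1, 0.1)
    nn.init.zeros_(decoder.bias)

    assert decoder.weight is weight_before   # still the same Parameter object
    assert decoder.weight.requires_grad      # autograd bookkeeping untouched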
