-
Notifications
You must be signed in to change notification settings - Fork 2
/
init_weights.py
42 lines (33 loc) · 1.38 KB
/
init_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright 2023 LL.
import torch.nn as nn
def kaiming_init(
    module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
):
    """Initialize ``module.weight`` with Kaiming (He) initialization.

    Also fills ``module.bias`` with a constant, when the module has one.

    Args:
        module: A layer with a ``weight`` tensor (e.g. Conv2d / Linear).
        a: Negative slope of the rectifier following this layer (used
            only when ``nonlinearity='leaky_relu'``).
        mode: ``'fan_in'`` or ``'fan_out'``; selects which fan preserves
            variance.
        nonlinearity: Name of the non-linearity, ``'relu'`` or
            ``'leaky_relu'``.
        bias: Constant value written into ``module.bias`` (if present).
        distribution: ``'uniform'`` or ``'normal'`` Kaiming variant.

    Raises:
        ValueError: If ``distribution`` is not ``'uniform'`` or ``'normal'``.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if distribution not in ("uniform", "normal"):
        raise ValueError(
            f"distribution must be 'uniform' or 'normal', got {distribution!r}"
        )
    if distribution == "uniform":
        nn.init.kaiming_uniform_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    else:
        nn.init.kaiming_normal_(
            module.weight, a=a, mode=mode, nonlinearity=nonlinearity
        )
    # Not every module has a bias (e.g. Conv2d(..., bias=False) sets it to None).
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution="normal"):
    """Initialize ``module.weight`` with Xavier (Glorot) initialization.

    Also fills ``module.bias`` with a constant, when the module has one.

    Args:
        module: A layer with a ``weight`` tensor (e.g. Conv2d / Linear).
        gain: Scaling factor applied to the computed standard deviation /
            bound.
        bias: Constant value written into ``module.bias`` (if present).
        distribution: ``'uniform'`` or ``'normal'`` Xavier variant.

    Raises:
        ValueError: If ``distribution`` is not ``'uniform'`` or ``'normal'``.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently skip this validation.
    if distribution not in ("uniform", "normal"):
        raise ValueError(
            f"distribution must be 'uniform' or 'normal', got {distribution!r}"
        )
    if distribution == "uniform":
        nn.init.xavier_uniform_(module.weight, gain=gain)
    else:
        nn.init.xavier_normal_(module.weight, gain=gain)
    # Not every module has a bias (e.g. Linear(..., bias=False) sets it to None).
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
    """Fill ``module.weight`` from N(mean, std**2); set bias to a constant.

    Args:
        module: A layer with a ``weight`` tensor.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
        bias: Constant value for ``module.bias`` (skipped when absent/None).
    """
    nn.init.normal_(module.weight, mean, std)
    # Guard: the module may have no bias attribute, or bias may be None.
    has_bias = getattr(module, "bias", None) is not None
    if has_bias:
        nn.init.constant_(module.bias, bias)
def constant_init(module, val, bias=0):
    """Fill ``module.weight`` and ``module.bias`` with constant values.

    Args:
        module: A layer; ``weight``/``bias`` are each filled only when the
            attribute exists and is not None.
        val: Constant value for the weight tensor.
        bias: Constant value for the bias tensor.
    """
    # Fill each parameter that is actually present on the module.
    for attr_name, fill_value in (("weight", val), ("bias", bias)):
        tensor = getattr(module, attr_name, None)
        if tensor is not None:
            nn.init.constant_(tensor, fill_value)