-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathUtilityFunctions.py
39 lines (29 loc) · 1.28 KB
/
UtilityFunctions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Small predicates intended for the ``nameFilter`` argument of
# getNBiggestWeights.

def I_FUNCTION(x):
    """Identity: return ``x`` unchanged.

    Used as a name filter it keeps every parameter whose name is truthy
    (i.e. non-empty).
    """
    return x


def BN_FILTER(n):
    """Return True when the name ``n`` does not contain the substring "BN".

    The check is case-sensitive, so e.g. "bn1.weight" is kept.
    """
    return "BN" not in n
def enumerateFloatTensor(tensor, indexList = None):
    """Flatten a tensor into ``(index, value)`` pairs in row-major order.

    Args:
        tensor: a ``torch.Tensor`` (anything exposing ``.shape``,
            ``.reshape`` and ``.tolist`` works, e.g. a NumPy array).
        indexList: optional prefix prepended to every index. Defaults to an
            empty prefix. It is copied, never mutated or returned by
            reference.

    Returns:
        A list with one ``(index, value)`` tuple per element:
        - ``index`` is a fresh list of ints: ``indexList`` followed by the
          element's position along each dimension.
        - ``value`` is the element as a Python scalar.
        A 0-dim tensor yields a single entry whose index is just the prefix.
        A tensor with a zero-sized dimension yields ``[]``.
    """
    import itertools

    # Copy so callers can never alias (and later mutate) a shared default.
    prefix = [] if indexList is None else list(indexList)
    # One bulk conversion instead of a per-element ``.item()`` on a view.
    # reshape(-1) flattens in row-major order, which matches the order in
    # which itertools.product walks the index ranges (last axis fastest).
    values = tensor.reshape(-1).tolist()
    positions = itertools.product(*(range(dim) for dim in tensor.shape))
    return [(prefix + list(pos), value) for pos, value in zip(positions, values)]
def getNBiggestWeights(mod, n = 10, biggest = True, absVal = True, nameFilter = I_FUNCTION):
    """Return the ``n`` most extreme individual weights of a module.

    Args:
        mod: module whose ``named_parameters()`` are scanned.
        n: number of entries to keep. Applied as a slice bound, so ``None``
            returns everything.
        biggest: if True, return the largest values first; otherwise the
            smallest values first.
        absVal: rank by absolute value instead of signed value.
        nameFilter: predicate on the parameter name. Only parameters for
            which it returns a truthy value are considered (e.g. BN_FILTER
            to skip names containing "BN").

    Returns:
        A list of ``(value, label)`` tuples ordered by ``value`` (descending
        when ``biggest``). ``label`` is ``"<paramName>_[i,j,...]"``. Ties
        keep parameter/element order.
    """
    paramList = []
    for paramName, paramTensor in mod.named_parameters():
        if not nameFilter(paramName):
            continue
        for indexList, fl in enumerateFloatTensor(paramTensor):
            label = paramName + "_" + str(indexList).replace(" ", "")
            paramList.append((abs(fl) if absVal else fl, label))
    # list.sort() takes keyword-only arguments in Python 3, so the original
    # positional ``sort(biggest)`` raised TypeError. Sort on the value alone;
    # the sort is stable (also with reverse=True), so ties keep their order.
    paramList.sort(key=lambda entry: entry[0], reverse=biggest)
    return paramList[:n]
def zeroLaterals(prognet, val = 0.0):
    """Set every state-dict entry whose name contains "laterals" to ``val``.

    Each matching tensor from ``prognet.state_dict()`` is filled in place with
    ``fill_``. The updated dict is then loaded back via ``load_state_dict``.
    Despite the function name, any fill value may be supplied.

    Args:
        prognet: a module exposing ``state_dict`` / ``load_state_dict``
            (presumably a progressive network with lateral connections;
            confirm against callers).
        val: scalar written into every element of the matching tensors.
    """
    stateDict = prognet.state_dict()
    # Collect the keys first, then overwrite them. No key is added or
    # removed, so the dict is never resized while in use.
    lateralNames = [name for name in stateDict if "laterals" in name]
    for name in lateralNames:
        stateDict[name] = stateDict[name].fill_(val)
    prognet.load_state_dict(stateDict)
#===============================================================================