forked from rawsh/mirrorllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib.py
36 lines (32 loc) · 1.24 KB
/
lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def test():
    """Return the constant sentinel string ``"ASDD"``."""
    sentinel = "ASDD"
    return sentinel
import copy
from typing import Dict, List, Tuple

import torch
def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
    """
    Remove the tensors from a PyTorch model, convert them to NumPy
    arrays, and return the stripped model and tensors.

    Args:
        m: The model whose tensors are extracted. It is not modified.

    Returns:
        A ``(stripped_model, tensors)`` pair.

        ``stripped_model`` is a deep copy of ``m``. In it, every parameter
        and buffer owned directly by each submodule is set to ``None``, and
        the copy is switched to inference mode via ``train(False)``.

        ``tensors`` has one entry per module yielded by
        ``m.named_modules()`` (root included), in that order. Each entry is
        ``{"params": {name: ndarray}, "buffers": {name: ndarray}}``. It holds
        copies of that module's own (non-recursive) parameters and buffers.
    """
    tensors = []
    for _, module in m.named_modules():
        # Clone so the arrays never alias the model's storage, and detach to
        # drop autograd history. Move to host memory because
        # Tensor.numpy() rejects tensors on non-CPU devices (e.g. CUDA);
        # for CPU tensors .cpu() is a no-op.
        params = {
            name: torch.clone(param).detach().cpu().numpy()
            for name, param in module.named_parameters(recurse=False)
        }
        buffers = {
            name: torch.clone(buf).detach().cpu().numpy()
            for name, buf in module.named_buffers(recurse=False)
        }
        tensors.append({"params": params, "buffers": buffers})

    # Strip a deep copy rather than ``m`` itself so the caller's model keeps
    # its weights. named_modules() walks the copy in the same order as the
    # original, so each ``tensors`` entry lines up with its module.
    m_copy = copy.deepcopy(m)
    for _, module in m_copy.named_modules():
        # Collect the names up front so the parameter/buffer registries are
        # not being assigned to while they are still being iterated.
        names = [name for name, _ in module.named_parameters(recurse=False)]
        names += [name for name, _ in module.named_buffers(recurse=False)]
        for name in names:
            setattr(module, name, None)

    # Make sure the copy is configured for inference.
    m_copy.train(False)
    return m_copy, tensors