-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathpruning_functions.py
36 lines (28 loc) · 1.25 KB
/
pruning_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import utils
def no_pruning(model, dataset, pruning_every=100):
    """No-op pruning strategy: leaves every octant active.

    Keeps the same signature as the real pruning functions so it can be
    used as a drop-in callback when pruning is disabled.
    """
    return None
def pruning_occupancy(model, dataset, threshold=-10):
    """Freeze octants the model is confident are empty.

    Evaluates the model on a batch of samples from ``dataset``, takes the
    max predicted occupancy per octant, and freezes any active octant whose
    max prediction falls below ``threshold`` while its accumulated error is
    already small (< 1e-3). Finishes by calling ``dataset.synchronize()``
    so the frozen state propagates.

    Args:
        model: Occupancy model passed to ``utils.process_batch_in_chunks``.
        dataset: Dataset exposing ``get_eval_samples``, an ``octtree`` with
            ``get_active_octants()``, and ``synchronize()``.
        threshold: Log-occupancy cutoff below which an octant is considered
            empty (default -10).
    """
    model_input = dataset.get_eval_samples(1)
    print("Pruning: loading data to cuda...")
    # Move tensor entries to the GPU with a leading batch dim; pass
    # everything else through untouched.
    model_input = {
        key: value[None, ...].cuda() if isinstance(value, torch.Tensor) else value
        for key, value in model_input.items()
    }
    print("Pruning: evaluating occupancy...")
    pred_occupancy = utils.process_batch_in_chunks(model_input, model)['model_out']['output']
    # Collapse the chunk dimension: keep the max prediction per sample.
    pred_occupancy = torch.max(pred_occupancy, dim=-2).values.squeeze()
    pred_occupancy_idx = model_input['coord_octant_idx'].squeeze()
    print("Pruning: computing mean and freezing empty octants")
    active_octants = dataset.octtree.get_active_octants()
    frozen_octants = 0
    for idx, octant in enumerate(active_octants):
        octant_preds = pred_occupancy[pred_occupancy_idx == idx]
        if octant_preds.numel() == 0:
            # No eval samples landed in this octant; torch.max on an empty
            # tensor would raise, and without evidence we must not freeze.
            continue
        max_prediction = torch.max(octant_preds)
        # Prune only if the model is confident everything is empty AND the
        # octant's fitting error is already negligible.
        if max_prediction < threshold and octant.err < 1e-3:
            octant.frozen = True
            frozen_octants += 1
    print(f"Pruning: Froze {frozen_octants} octants.")
    dataset.synchronize()