-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
126 lines (98 loc) · 3.34 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from dataset import create_wall_dataloader
from evaluator import ProbingEvaluator
import torch
from best_model import *
import glob
def get_device():
    """Return the torch device to run on, preferring CUDA when it is available.

    The chosen device is also printed so the run log records it.
    """
    backend = "cpu"
    if torch.cuda.is_available():
        backend = "cuda"
    chosen = torch.device(backend)
    print("Using device:", chosen)
    return chosen
def load_data(device, data_path="/scratch/DL24FA"):
    """Build the probing training loader and the per-split validation loaders.

    Args:
        device: Device forwarded to ``create_wall_dataloader``.
        data_path: Root directory holding the ``probe_normal``, ``probe_wall``
            and ``probe_wall_other`` dataset folders. Defaults to the original
            hard-coded location.

    Returns:
        ``(probe_train_ds, probe_val_ds)``. ``probe_train_ds`` is built from
        ``probe_normal/train``. ``probe_val_ds`` maps ``"normal"``, ``"wall"``
        and ``"wall_other"`` to the loaders for the matching ``.../val`` folders.
    """

    def make_loader(subdir, train):
        # Every probing loader differs only in its sub-folder and train flag.
        return create_wall_dataloader(
            data_path=f"{data_path}/{subdir}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_ds = make_loader("probe_normal/train", train=True)
    probe_val_ds = {
        "normal": make_loader("probe_normal/val", train=False),
        "wall": make_loader("probe_wall/val", train=False),
        "wall_other": make_loader("probe_wall_other/val", train=False),
    }
    return probe_train_ds, probe_val_ds
def load_expert_data(device, data_path="/scratch/DL24FA"):
    """Build the probing loaders for the ``probe_expert`` dataset.

    Args:
        device: Device forwarded to ``create_wall_dataloader``.
        data_path: Root directory holding the ``probe_expert`` folder.
            Defaults to the original hard-coded location.

    Returns:
        ``(probe_train_expert_ds, probe_val_expert_ds)``.
        ``probe_train_expert_ds`` is built from ``probe_expert/train``.
        ``probe_val_expert_ds`` maps ``"expert"`` to the ``probe_expert/val``
        loader, matching the dict shape that ``load_data`` returns.
    """

    def make_loader(split, train):
        # Train and val loaders differ only in the split folder and train flag.
        return create_wall_dataloader(
            data_path=f"{data_path}/probe_expert/{split}",
            probing=True,
            device=device,
            train=train,
        )

    probe_train_expert_ds = make_loader("train", train=True)
    probe_val_expert_ds = {"expert": make_loader("val", train=False)}
    return probe_train_expert_ds, probe_val_expert_ds
def load_model(
    device=None,
    state_dict_path=(
        "/scratch/fc1132/JEPA_world_model/encoder_outputs/"
        "trained_recurrent_jepa_Z_tuned.pth"
    ),
):
    """Build the JEPA model, load its trained weights and switch it to eval mode.

    Args:
        device: Device to place the model on. When None, CUDA is used if
            available and CPU otherwise. This is the same rule ``get_device``
            applies to the dataloaders and evaluator.
        state_dict_path: Path to the saved state dict. Defaults to the
            original hard-coded checkpoint.

    Returns:
        The ``JEPA`` model on ``device`` with the checkpoint weights loaded,
        in eval mode.
    """
    if device is None:
        # Must agree with get_device(). Picking e.g. "mps" here while the
        # probing data lives on cpu would put the model and its inputs on
        # different devices.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Architecture hyperparameters. They must match the configuration the
    # checkpoint was trained with, or load_state_dict will reject it.
    state_dim = 256
    action_dim = 2
    hidden_dim = 128
    ema_rate = 0.99
    cnn_channels = 64
    model = JEPA(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=hidden_dim,
        ema_rate=ema_rate,
        cnn_channels=cnn_channels,
    ).to(device)

    state_dict = torch.load(state_dict_path, map_location=device)
    # Checkpoints saved from a torch.compile'd model carry an "_orig_mod."
    # segment in their parameter names. Strip it so the keys match the
    # plain (uncompiled) JEPA.
    new_state_dict = {
        key.replace("_orig_mod.", ""): value for key, value in state_dict.items()
    }
    model.load_state_dict(new_state_dict)
    model.eval()
    return model
def evaluate_model(device, model, probe_train_ds, probe_val_ds, quick_debug=False):
    """Train a prober on ``model`` and print the average validation losses.

    Args:
        device: Device forwarded to ``ProbingEvaluator``.
        model: The model being probed.
        probe_train_ds: Loader used to train the prober.
        probe_val_ds: Mapping of split name -> validation loader, as returned
            by ``load_data`` / ``load_expert_data``.
        quick_debug: Forwarded to ``ProbingEvaluator``. Defaults to False, the
            previously hard-coded value (presumably True gives a shortened
            debug run; confirm in evaluator.py).

    Returns:
        The ``avg_losses`` mapping from ``evaluate_all``. It is also printed
        one ``"<probe_attr> loss: <value>"`` line per entry.
    """
    evaluator = ProbingEvaluator(
        device=device,
        model=model,
        probe_train_ds=probe_train_ds,
        probe_val_ds=probe_val_ds,
        quick_debug=quick_debug,
    )
    prober = evaluator.train_pred_prober()
    avg_losses = evaluator.evaluate_all(prober=prober)
    for probe_attr, loss in avg_losses.items():
        print(f"{probe_attr} loss: {loss}")
    # Returned so callers can log or compare results without parsing stdout.
    return avg_losses
if __name__ == "__main__":
    # This device is passed to every dataloader and to the evaluator.
    run_device = get_device()
    jepa = load_model()
    # Probe the same model twice: first on the standard probing splits,
    # then on the expert split.
    for load_split in (load_data, load_expert_data):
        train_loader, val_loaders = load_split(run_device)
        evaluate_model(run_device, jepa, train_loader, val_loaders)