-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_6d.py
107 lines (84 loc) · 3.44 KB
/
train_6d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
import ast
from model_6d import PointNet
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
# Module-level TensorBoard writer used by train_pointnet() (scalar logging)
# and main() (final flush).
# NOTE(review): it is constructed at import time with default arguments, so
# merely importing this module sets up a new run log (presumably under the
# default ./runs/ directory — confirm against torch.utils.tensorboard docs).
writer = SummaryWriter()
class PointCloudDataset(Dataset):
    """Point-cloud classification dataset read from a CSV file.

    Expected CSV columns:
        feature: string encoding of a nested list of points, presumably like
            ``"[[x, y, z, ...];[x, y, z, ...]]"`` (points may be separated by
            ``;`` or ``,``).
        label: class of the row; any value ``LabelEncoder`` accepts.

    Attributes:
        data: The raw ``DataFrame`` returned by ``pandas.read_csv``.
        points: ``Series`` of parsed per-row point arrays.
        label_encoder: Fitted ``LabelEncoder``; ``label_encoder.classes_``
            maps integer ids back to the original label values.
        labels: Integer class id per row, aligned positionally with ``points``.

    NOTE(review): the default ``DataLoader`` collate stacks samples, so every
    row is assumed to contain the same number of points — confirm for the data.
    """

    def __init__(self, csv_file):
        """Read ``csv_file`` and parse every row's points and label.

        Args:
            csv_file: Path or buffer accepted by ``pandas.read_csv``.
        """
        self.data = pd.read_csv(csv_file)
        # ';' between points is not valid Python literal syntax; rewrite
        # "];[" as "], [" so ast.literal_eval (literals only, no arbitrary
        # code execution) can parse the row into a nested list.
        self.points = self.data['feature'].apply(
            lambda x: np.array(ast.literal_eval(x.replace('];[', '], [')))
        )
        # Map arbitrary label values to consecutive integer class ids.
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.data['label'])

    def __len__(self):
        """Return the number of samples (CSV rows)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(points, label)`` for the sample at position ``idx``.

        Returns:
            points: float32 tensor with the row's parsed shape (presumably
                ``(N, C)``).
            label: 0-d int64 tensor holding the encoded class id.
        """
        # .iloc keeps the lookup positional, matching the positional
        # ``labels`` array even if the DataFrame index is not a RangeIndex.
        points = torch.as_tensor(self.points.iloc[idx], dtype=torch.float32)
        label = torch.as_tensor(self.labels[idx], dtype=torch.long)
        return points, label
def train_pointnet(model, train_loader, num_epochs=100, learning_rate=0.001,
                   device='cuda', summary_writer=None):
    """Train a PointNet-style classifier with Adam and log per-epoch metrics.

    Args:
        model: Module whose forward pass returns ``(outputs, t6, t64)``, with
            per-class scores along dim 1 of ``outputs``, and which provides
            ``model.loss(outputs, labels, t6, t64)`` returning a scalar loss
            tensor.
        train_loader: Sized iterable (e.g. a ``DataLoader``) yielding
            ``(points, labels)`` batches; ``points`` must be 3-D ``(B, N, C)``.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Learning rate for the Adam optimizer.
        device: Device the model and each batch are moved to.
        summary_writer: Object with ``add_scalar(tag, value, step)`` (normally
            a TensorBoard ``SummaryWriter``). Defaults to the module-level
            ``writer``.

    Returns:
        The trained model, moved to ``device``.
    """
    if summary_writer is None:
        summary_writer = writer
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    with tqdm(total=num_epochs * len(train_loader), desc="Processing files") as pbar:
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0.0
            num_batches = 0
            correct = 0
            total = 0
            for points, labels in train_loader:
                optimizer.zero_grad()
                points = points.to(device)
                labels = labels.to(device)
                # The model expects channels first: (B, N, C) -> (B, C, N).
                points = points.permute(0, 2, 1)

                outputs, t6, t64 = model(points)
                # model.loss lives in model_6d (not visible here); it receives
                # both transform outputs alongside the class scores.
                loss = model.loss(outputs, labels, t6, t64)
                loss.backward()
                optimizer.step()

                # Running statistics for the current epoch. A single .item()
                # call avoids repeated device->host syncs.
                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                pbar.update(1)
                pbar.set_postfix_str(
                    f"Loss: {batch_loss:.4f}, Acc: {100.*correct/total:.2f}%"
                )

            # Log once per epoch so each TensorBoard step holds exactly one
            # epoch-level value (mean batch loss, epoch accuracy).
            if num_batches and total:
                summary_writer.add_scalar("Loss/train", total_loss / num_batches, epoch)
                summary_writer.add_scalar("Accuracy/train", correct / total, epoch)
    return model
def main(csv_file='data_generation/train_6d.csv', model_path='pointnet_model.pth',
         num_epochs=100, learning_rate=0.001, batch_size=32):
    """Train a PointNet classifier on a CSV point-cloud dataset and save it.

    The defaults reproduce the script's standard run, so ``main()`` with no
    arguments is the plain entry point.

    Args:
        csv_file: Training CSV read by ``PointCloudDataset``.
        model_path: Where the trained model's ``state_dict`` is written.
        num_epochs: Number of training epochs.
        learning_rate: Adam learning rate.
        batch_size: Mini-batch size for the shuffled ``DataLoader``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print('Loading dataset...')
    dataset = PointCloudDataset(csv_file)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print('Dataset loaded!')

    # The classifier head size comes from the labels actually present in the CSV.
    num_classes = len(dataset.label_encoder.classes_)
    model = PointNet(num_classes=num_classes)
    print('Model initialized!')

    try:
        trained_model = train_pointnet(
            model=model,
            train_loader=train_loader,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            device=device,
        )
    finally:
        # Push buffered TensorBoard events to disk even when training raises
        # or is interrupted, so partial curves are not lost.
        writer.flush()

    # NOTE(review): only the weights are saved; the label-encoder class order
    # (dataset.label_encoder.classes_) is not, so inference code must rebuild
    # the same mapping to interpret predictions.
    torch.save(trained_model.state_dict(), model_path)
    print("Training completed and model saved!")
if __name__ == "__main__":
    # Script entry point: training runs only when this file is executed
    # directly, never as a side effect of importing it.
    main()