Skip to content

Commit

Permalink
fix bugs highlited in pre-commit ci
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliotMaitre committed Mar 25, 2024
1 parent 75451cc commit 4bf9b57
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,6 @@
from happypose.toolbox.datasets.scene_dataset_wrappers import remove_invisible_objects


def collate_fn(batch):
rgbs, targets = zip(*batch)
# Stack the rgbs and convert to a tensor
rgbs = torch.stack(rgbs, dim=0)

# Initialize the target dictionary
target = {
"boxes": [],
"labels": [],
"masks": [],
"image_id": [],
"area": [],
"iscrowd": [],
}

# Concatenate the target data for each image in the batch
for t in targets:
target["boxes"].append(t["boxes"])
target["labels"].append(t["labels"])
target["masks"].append(t["masks"])
target["image_id"].append(t["image_id"])
target["area"].append(t["area"])
target["iscrowd"].append(t["iscrowd"])

# Stack the target data and convert to tensors
for key in target.keys():
target[key] = torch.cat(target[key], dim=0)

# Return the batch data as a dictionary
return {"rgbs": rgbs, "targets": target}


# TODO : Double check on types and add format documentation
@dataclass
class DetectionData:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@ def __getitem__(self, idx):
)
object_datas.append(obj_data)

state = {
"camera": cam,
"objects": objects,
"frame_info": self.frame_index.iloc[idx].to_dict(),
}
image_infos = ObservationInfos(
scene_id=self.frame_index.iloc[idx].to_dict()["scene_id"],
view_id=self.frame_index.iloc[idx].to_dict()["view_id"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import os
import warnings

import numpy as np
from colorama import Fore, Style
Expand All @@ -9,12 +10,9 @@

logger = get_logger(__name__)


import warnings

# TODO : Fix warnings
warnings.filterwarnings("ignore")


def make_cfg(args):
cfg = argparse.ArgumentParser("").parse_args([])
if args.config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

from happypose.pose_estimators.cosypose.cosypose.config import EXP_DIR
from happypose.pose_estimators.cosypose.cosypose.datasets.detection_dataset import (
DetectionDataset,
collate_fn,
DetectionDataset
)
from happypose.pose_estimators.cosypose.cosypose.integrated.detector import Detector

Expand Down Expand Up @@ -51,11 +50,9 @@
cudnn.benchmark = True
logger = get_logger(__name__)


def collate_fn(batch):
return tuple(zip(*batch))


def make_eval_configs(args, model_training, epoch):
model = model_training.module
model.config = args
Expand Down Expand Up @@ -335,6 +332,7 @@ def train_epoch():
if epoch < args.n_epochs_warmup:
lr_scheduler_warmup.step()
t = time.time()

if epoch >= args.n_epochs_warmup:
lr_scheduler.step()

Expand All @@ -345,6 +343,7 @@ def validation():
loss = h(data=sample, meters=meters_val)
meters_val["loss_total"].add(loss.item())


train_epoch()
if epoch % args.val_epoch_interval == 0:
validation()
Expand Down

0 comments on commit 4bf9b57

Please sign in to comment.