Skip to content

Commit

Permalink
merging?
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliotMaitre committed Mar 25, 2024
2 parents 5c00e07 + 400d14d commit 02b71cd
Show file tree
Hide file tree
Showing 22 changed files with 209 additions and 318 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import random
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -28,8 +27,15 @@
SceneObservationAugmentation as SceneObsAug,
)

# HappyPose
from happypose.toolbox.datasets.scene_dataset import (
IterableSceneDataset,
SceneDataset,
SceneObservation,
)
from happypose.toolbox.datasets.scene_dataset_wrappers import remove_invisible_objects


def collate_fn(batch):
rgbs, targets = zip(*batch)
# Stack the rgbs and convert to a tensor
Expand Down Expand Up @@ -61,7 +67,8 @@ def collate_fn(batch):
# Return the batch data as a dictionary
return {"rgbs": rgbs, "targets": target}

#TODO : Double check on types and add format documentation

# TODO : Double check on types and add format documentation
@dataclass
class DetectionData:
"""rgb: (h, w, 3) uint8
Expand Down Expand Up @@ -104,6 +111,7 @@ def pin_memory(self) -> "BatchDetectionData":
self.depths = self.depths.pin_memory()
return self


class DetectionDataset(torch.utils.data.IterableDataset):
def __init__(
self,
Expand Down Expand Up @@ -137,27 +145,26 @@ def __init__(
SceneObsAug(PillowSharpness(factor_interval=(0.0, 50.0)), p=0.3),
SceneObsAug(PillowContrast(factor_interval=(0.2, 50.0)), p=0.3),
SceneObsAug(PillowBrightness(factor_interval=(0.1, 6.0)), p=0.5),
SceneObsAug(PillowColor(factor_interval=(0.0, 20.0)), p=0.3),]
SceneObsAug(PillowColor(factor_interval=(0.0, 20.0)), p=0.3),
]
)

]

self.label_to_category_id = label_to_category_id
self.min_area = min_area

def make_data_from_obs(self, obs: SceneObservation, idx):

obs = remove_invisible_objects(obs)

obs = self.resize_augmentation(obs)

for aug in self.background_augmentations:
obs = aug(obs)

if self.rgb_augmentations and random.random() < 0.8:
for aug in self.rgb_augmentations:
obs = aug(obs)

assert obs.object_datas is not None
assert obs.rgb is not None
assert obs.camera_data is not None
Expand All @@ -177,13 +184,13 @@ def make_data_from_obs(self, obs: SceneObservation, idx):
if obs.binary_masks is not None:
binary_mask = torch.tensor(obs.binary_masks[obj_data.unique_id]).float()
masks.append(binary_mask)

if obs.segmentation is not None:
binary_mask = np.zeros_like(obs.segmentation, dtype=np.bool_)
binary_mask[obs.segmentation == obj_data.unique_id] = 1
binary_mask = torch.as_tensor(binary_mask).float()
masks.append(binary_mask)

masks = np.array(masks)
masks = masks == obj_ids[:, None, None]

Expand Down Expand Up @@ -216,8 +223,8 @@ def __getitem__(self, index: int):
assert isinstance(self.scene_ds, SceneDataset)
obs = self.scene_ds[index]
return self.make_data_from_obs(obs, index)
# def find_valid_data(self, iterator: Iterator[SceneObservation]) -> PoseData:

# def find_valid_data(self, iterator: Iterator[SceneObservation]) -> PoseData:
def find_valid_data(self, iterator):
n_attempts = 0
for idx, obs in enumerate(iterator):
Expand All @@ -228,7 +235,7 @@ def find_valid_data(self, iterator):
if n_attempts > 200:
msg = "Cannot find valid image in the dataset"
raise ValueError(msg)

def __iter__(self):
assert isinstance(self.scene_ds, IterableSceneDataset)
iterator = iter(self.scene_ds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
CameraData,
ObjectData,
ObservationInfos,
SceneDataset,
SceneObservation,
)
from happypose.toolbox.lib3d.transform import Transform
Expand Down Expand Up @@ -82,33 +81,33 @@ def __getitem__(self, idx):
if obj["id_in_segm"] in mask_uniqs:
obj["bbox"] = dets_gt[obj["id_in_segm"]].numpy()
obj_data = ObjectData(
label= obj['name'],
TWO=Transform(obj['TWO']),
unique_id=obj['body_id'],
bbox_modal=obj['bbox']
label=obj["name"],
TWO=Transform(obj["TWO"]),
unique_id=obj["body_id"],
bbox_modal=obj["bbox"],
)
else:
obj_data = ObjectData(
label= obj['name'],
TWO=Transform(obj['TWO']),
unique_id=obj['body_id']
label=obj["name"],
TWO=Transform(obj["TWO"]),
unique_id=obj["body_id"],
)
object_datas.append(obj_data)

state = {
"camera": cam,
"objects": objects,
"frame_info": self.frame_index.iloc[idx].to_dict(),
}
image_infos = ObservationInfos(
scene_id=self.frame_index.iloc[idx].to_dict()['scene_id'],
view_id=self.frame_index.iloc[idx].to_dict()['view_id'])

scene_id=self.frame_index.iloc[idx].to_dict()["scene_id"],
view_id=self.frame_index.iloc[idx].to_dict()["view_id"],
)

cam = CameraData(
K=cam["K"],
resolution=cam["resolution"],
TWC=Transform(cam["TWC"]))

K=cam["K"], resolution=cam["resolution"], TWC=Transform(cam["TWC"])
)

observation = SceneObservation(
rgb=rgb.numpy().astype(np.uint8),
segmentation=mask.numpy().astype(np.uint32),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,18 @@ def run_inference_pipeline(
"""
if self.inference_cfg['detection_type'] == "gt":
if self.inference_cfg["detection_type"] == "gt":
detections = gt_detections
run_detector = False
elif self.inference_cfg['detection_type'] == "detector":
elif self.inference_cfg["detection_type"] == "detector":
detections = None
run_detector = True
else:
msg = f"Unknown detection type {self.inference_cfg['detection_type']}"
raise ValueError(msg)

coarse_estimates = None
if self.inference_cfg['coarse_estimation_type'] == "external":
if self.inference_cfg["coarse_estimation_type"] == "external":
# TODO (ylabbe): This is hacky, clean this for modelnet eval.
coarse_estimates = initial_estimates
coarse_estimates = happypose.toolbox.inference.utils.add_instance_id(
Expand Down Expand Up @@ -144,7 +144,7 @@ def run_inference_pipeline(
"coarse": extra_data["coarse"]["preds"],
}

if self.inference_cfg['run_depth_refiner']:
if self.inference_cfg["run_depth_refiner"]:
all_preds["depth_refiner"] = extra_data["depth_refiner"]["preds"]

# Remove any mask tensors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
def run_pred_eval(pred_runner, pred_kwargs, eval_runner, eval_preds=None):
all_predictions = {}
for pred_prefix, pred_kwargs_n in pred_kwargs.items():
pose_predictor = pred_kwargs_n['pose_predictor']
pose_predictor = pred_kwargs_n["pose_predictor"]
preds = pred_runner.get_predictions(pose_predictor)
for preds_name, preds_n in preds.items():
all_predictions[f"{pred_prefix}/{preds_name}"] = preds_n
Expand Down
2 changes: 1 addition & 1 deletion happypose/pose_estimators/cosypose/cosypose/models/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def forward(self, images, K, labels, TCO, n_iterations=1):
K_crop=K_crop,
boxes_rend=boxes_rend,
boxes_crop=boxes_crop,
model_outputs=model_outputs
model_outputs=model_outputs,
)

TCO_input = TCO_output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
RESULTS_DIR,
)
from happypose.pose_estimators.cosypose.cosypose.datasets.bop import remap_bop_targets
from happypose.toolbox.datasets.datasets_cfg import (
make_object_dataset,
make_scene_dataset,
)
from happypose.pose_estimators.cosypose.cosypose.datasets.samplers import ListSampler
from happypose.pose_estimators.cosypose.cosypose.datasets.wrappers.multiview_wrapper import (
MultiViewWrapper,
Expand Down Expand Up @@ -61,6 +57,10 @@
init_distributed_mode,
)
from happypose.pose_estimators.cosypose.cosypose.utils.logging import get_logger
from happypose.toolbox.datasets.datasets_cfg import (
make_object_dataset,
make_scene_dataset,
)
from happypose.toolbox.lib3d.transform import Transform
from happypose.toolbox.renderer.bullet_batch_renderer import BulletBatchRenderer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def run_inference(
rgb, depth, camera_data = load_observation_example(example_dir, load_depth=True)
# TODO: cosypose forward does not work if depth is loaded detection
# contrary to megapose
observation = ObservationTensor.from_numpy(rgb, depth=None, K=camera_data.K).to(device)
observation = ObservationTensor.from_numpy(rgb, depth=None, K=camera_data.K).to(
device
)

# Load models
pose_estimator = setup_pose_estimator(args.dataset, object_dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@


import warnings

warnings.filterwarnings("ignore")


def make_cfg(args):
cfg = argparse.ArgumentParser("").parse_args([])
if args.config:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def h_pose(model, mesh_db, data, meters, cfg, n_iterations=1, input_generator="f
TCO_pred = iter_outputs.TCO_output
model_outputs = iter_outputs.model_outputs


if cfg.loss_disentangled:
if cfg.n_pose_dims == 9:
loss_fn = loss_refiner_CO_disentangled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
import yaml
from torch.backends import cudnn
from torch.hub import load_state_dict_from_url
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data import DataLoader
from torchnet.meter import AverageValueMeter
from torchvision.models.detection.mask_rcnn import model_urls
from tqdm import tqdm

from happypose.pose_estimators.cosypose.cosypose.config import EXP_DIR
from happypose.pose_estimators.cosypose.cosypose.datasets.detection_dataset import (
DetectionDataset,
collate_fn
collate_fn,
)
from happypose.pose_estimators.cosypose.cosypose.datasets.samplers import PartialSampler
from happypose.pose_estimators.cosypose.cosypose.integrated.detector import Detector

# Evaluation
Expand All @@ -35,25 +34,17 @@
sync_model,
)
from happypose.pose_estimators.cosypose.cosypose.utils.logging import get_logger
from happypose.pose_estimators.cosypose.cosypose.utils.multiepoch_dataloader import (
MultiEpochDataLoader,
)
from happypose.toolbox.datasets.datasets_cfg import make_scene_dataset

from happypose.toolbox.datasets.scene_dataset import (
IterableMultiSceneDataset,
IterableSceneDataset,
RandomIterableSceneDataset,
SceneDataset,
)

from happypose.toolbox.utils.resources import (
get_cuda_memory,
get_gpu_memory,
get_total_memory,
)


from .detector_models_cfg import check_update_config, create_model_detector
from .maskrcnn_forward_loss import h_maskrcnn

Expand Down Expand Up @@ -154,7 +145,6 @@ def train_detector(args):
args.n_gpus = world_size
args.global_batch_size = world_size * args.batch_size
logger.info(f"Connection established with {world_size} gpus.")


# Make train/val datasets
def make_datasets(dataset_names):
Expand All @@ -168,9 +158,11 @@ def make_datasets(dataset_names):
logger.info(f"Loaded {ds_name} with {len(ds)} images.")
pre_label = ds_name.split(".")[0]
for idx, label in enumerate(ds.all_labels):
ds.all_labels[idx] = "{pre_label}-{label}".format(pre_label=pre_label, label=label)
ds.all_labels[idx] = "{pre_label}-{label}".format(
pre_label=pre_label, label=label
)
all_labels = all_labels.union(set(ds.all_labels))

for _ in range(n_repeat):
datasets.append(iterator)
return IterableMultiSceneDataset(datasets), all_labels
Expand All @@ -193,13 +185,13 @@ def make_datasets(dataset_names):
"gray_augmentation": args.gray_augmentation,
"label_to_category_id": label_to_category_id,
}

print("make datasets")
ds_train = DetectionDataset(scene_ds_train, **ds_kwargs)
ds_val = DetectionDataset(scene_ds_val, **ds_kwargs)

print("make dataloaders")
#train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
# train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
ds_iter_train = DataLoader(
ds_train,
batch_size=args.batch_size,
Expand All @@ -210,7 +202,7 @@ def make_datasets(dataset_names):
)
ds_iter_train = iter(ds_iter_train)

#val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
# val_sampler = PartialSampler(ds_val, epoch_size=int(0.1 * args.epoch_size))
ds_iter_val = DataLoader(
ds_val,
batch_size=args.batch_size,
Expand Down Expand Up @@ -334,7 +326,9 @@ def train_epoch():
max_norm=np.inf,
norm_type=2,
)
meters_train["grad_norm"].add(torch.as_tensor(total_grad_norm).item())
meters_train["grad_norm"].add(
torch.as_tensor(total_grad_norm).item()
)

optimizer.step()
meters_time["backward"].add(time.time() - t)
Expand Down Expand Up @@ -364,8 +358,8 @@ def validation():
if epoch % args.val_epoch_interval == 0:
validation()

#test_dict = None
#if epoch % args.test_epoch_interval == 0:
# test_dict = None
# if epoch % args.test_epoch_interval == 0:
# model.eval()
# test_dict = run_eval(args, model, epoch)

Expand All @@ -387,7 +381,6 @@ def validation():
},
)


for string, meters in zip(("train", "val"), (meters_train, meters_val)):
for k in dict(meters).keys():
log_dict[f"{string}_{k}"] = meters[k].mean
Expand Down
Loading

0 comments on commit 02b71cd

Please sign in to comment.