Merge pull request #149 from agimus-project/training_cosypose
Training cosypose
nim65s authored Mar 26, 2024
2 parents 0aba10a + b1f47cf commit deb0965
Showing 29 changed files with 1,085 additions and 881 deletions.
@@ -143,7 +143,6 @@ def __call__(self, im, mask, obs):
 
 class VOCBackgroundAugmentation(BackgroundAugmentation):
     def __init__(self, voc_root, p=0.3):
-        print("voc_root =", voc_root)
         image_dataset = VOCSegmentation(
             root=voc_root,
             year="2012",
@@ -4,21 +4,30 @@
 import torch
 
 from happypose.pose_estimators.cosypose.cosypose.config import LOCAL_DATA_DIR
-
-from .augmentations import (
-    CropResizeToAspectAugmentation,
+from happypose.toolbox.datasets.augmentations import (
+    CropResizeToAspectTransform,
     PillowBlur,
     PillowBrightness,
     PillowColor,
     PillowContrast,
     PillowSharpness,
     VOCBackgroundAugmentation,
-    to_torch_uint8,
 )
-from .wrappers.visibility_wrapper import VisibilityWrapper
+from happypose.toolbox.datasets.augmentations import (
+    SceneObservationAugmentation as SceneObsAug,
+)
 
+# HappyPose
+# HappyPose
+from happypose.toolbox.datasets.scene_dataset import (
+    IterableSceneDataset,
+    SceneDataset,
+    SceneObservation,
+)
+from happypose.toolbox.datasets.scene_dataset_wrappers import remove_invisible_objects
 
-class DetectionDataset(torch.utils.data.Dataset):
+
+class DetectionDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         scene_ds,
@@ -29,59 +38,76 @@ def __init__(
         rgb_augmentation=False,
         background_augmentation=False,
     ):
-        self.scene_ds = VisibilityWrapper(scene_ds)
-
-        self.resize_augmentation = CropResizeToAspectAugmentation(resize=resize)
-
-        self.background_augmentation = background_augmentation
-        self.background_augmentations = VOCBackgroundAugmentation(
-            voc_root=LOCAL_DATA_DIR / "VOCdevkit/VOC2012",
-            p=0.3,
-        )
+        self.scene_ds = scene_ds
+
+        self.resize_augmentation = CropResizeToAspectTransform()
+
+        self.background_augmentations = []
+        self.background_augmentations += [
+            (
+                SceneObsAug(
+                    VOCBackgroundAugmentation(LOCAL_DATA_DIR),
+                    p=0.3,
+                )
+            ),
+        ]
 
-        self.rgb_augmentation = rgb_augmentation
-        self.rgb_augmentations = [
-            PillowBlur(p=0.4, factor_interval=(1, 3)),
-            PillowSharpness(p=0.3, factor_interval=(0.0, 50.0)),
-            PillowContrast(p=0.3, factor_interval=(0.2, 50.0)),
-            PillowBrightness(p=0.5, factor_interval=(0.1, 6.0)),
-            PillowColor(p=0.3, factor_interval=(0.0, 20.0)),
+        self.rgb_augmentations = []
+        self.rgb_augmentations += [
+            SceneObsAug(
+                [
+                    SceneObsAug(PillowBlur(factor_interval=(1, 3)), p=0.4),
+                    SceneObsAug(PillowSharpness(factor_interval=(0.0, 50.0)), p=0.3),
+                    SceneObsAug(PillowContrast(factor_interval=(0.2, 50.0)), p=0.3),
+                    SceneObsAug(PillowBrightness(factor_interval=(0.1, 6.0)), p=0.5),
+                    SceneObsAug(PillowColor(factor_interval=(0.0, 20.0)), p=0.3),
+                ]
+            )
         ]
 
         self.label_to_category_id = label_to_category_id
         self.min_area = min_area
 
-    def __len__(self):
-        return len(self.scene_ds)
+    def make_data_from_obs(self, obs: SceneObservation, idx):
+        obs = remove_invisible_objects(obs)
 
-    def get_data(self, idx):
-        rgb, mask, state = self.scene_ds[idx]
+        obs = self.resize_augmentation(obs)
 
-        rgb, mask, state = self.resize_augmentation(rgb, mask, state)
+        for aug in self.background_augmentations:
+            obs = aug(obs)
 
-        if self.background_augmentation:
-            rgb, mask, state = self.background_augmentations(rgb, mask, state)
-
-        if self.rgb_augmentation and random.random() < 0.8:
-            for augmentation in self.rgb_augmentations:
-                rgb, mask, state = augmentation(rgb, mask, state)
-
-        rgb, mask = to_torch_uint8(rgb), to_torch_uint8(mask)
+        if self.rgb_augmentations and random.random() < 0.8:
+            for aug in self.rgb_augmentations:
+                obs = aug(obs)
 
+        assert obs.object_datas is not None
+        assert obs.rgb is not None
+        assert obs.camera_data is not None
         categories = torch.tensor(
-            [self.label_to_category_id[obj["name"]] for obj in state["objects"]],
+            [self.label_to_category_id[obj.label] for obj in obs.object_datas],
        )
-        obj_ids = np.array([obj["id_in_segm"] for obj in state["objects"]])
         boxes = np.array(
-            [torch.as_tensor(obj["bbox"]).tolist() for obj in state["objects"]],
+            [torch.as_tensor(obj.bbox_modal).tolist() for obj in obs.object_datas],
         )
         boxes = torch.as_tensor(boxes, dtype=torch.float32).view(-1, 4)
         area = torch.as_tensor(
             (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
         )
-        mask = np.array(mask)
-        masks = mask == obj_ids[:, None, None]
-        masks = torch.as_tensor(masks)
+        obj_ids = np.array([obj.unique_id for obj in obs.object_datas])
+
+        masks = []
+        for _n, obj_data in enumerate(obs.object_datas):
+            if obs.binary_masks is not None:
+                binary_mask = torch.tensor(obs.binary_masks[obj_data.unique_id]).float()
+                masks.append(binary_mask)
+
+            if obs.segmentation is not None:
+                binary_mask = np.zeros_like(obs.segmentation, dtype=np.bool_)
+                binary_mask[obs.segmentation == obj_data.unique_id] = 1
+                binary_mask = torch.as_tensor(binary_mask).float()
+                masks.append(binary_mask)
+
+        masks = np.array(masks)
+        masks = masks == obj_ids[:, None, None]
 
         keep = area > self.min_area
         boxes = boxes[keep]
@@ -97,26 +123,36 @@ def get_data(self, idx):
         image_id = torch.tensor([idx])
         iscrowd = torch.zeros((num_objs), dtype=torch.int64)
 
+        rgb = torch.as_tensor(obs.rgb)
         target = {}
         target["boxes"] = boxes
         target["labels"] = categories
         target["masks"] = masks
         target["image_id"] = image_id
         target["area"] = area
         target["iscrowd"] = iscrowd
 
         return rgb, target
 
-    def __getitem__(self, index):
-        try_index = index
-        valid = False
+    def __getitem__(self, index: int):
+        assert isinstance(self.scene_ds, SceneDataset)
+        obs = self.scene_ds[index]
+        return self.make_data_from_obs(obs, index)
+
+    # def find_valid_data(self, iterator: Iterator[SceneObservation]) -> PoseData:
+    def find_valid_data(self, iterator):
         n_attempts = 0
-        while not valid:
-            if n_attempts > 10:
+        for idx, obs in enumerate(iterator):
+            data = self.make_data_from_obs(obs, idx)
+            if data is not None:
+                return data
+            n_attempts += 1
+            if n_attempts > 200:
                 msg = "Cannot find valid image in the dataset"
                 raise ValueError(msg)
-            im, target = self.get_data(try_index)
-            valid = len(target["boxes"]) > 0
-            if not valid:
-                try_index = random.randint(0, len(self.scene_ds) - 1)
-                n_attempts += 1
-        return im, target
+
+    def __iter__(self):
+        assert isinstance(self.scene_ds, IterableSceneDataset)
+        iterator = iter(self.scene_ds)
+        while True:
+            yield self.find_valid_data(iterator)
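
Note on the reworked DetectionDataset: it is now a torch.utils.data.IterableDataset, so samples stream from __iter__ (via find_valid_data) instead of being fetched by random index. A minimal consumption sketch, assuming an IterableSceneDataset instance scene_ds; the collate function and the constructor arguments shown are illustrative assumptions, not part of this diff:

import torch

def collate_fn(batch):
    # torchvision-style detection collate: keep (rgb, target) pairs together
    return tuple(zip(*batch))

ds = DetectionDataset(
    scene_ds,                                # an IterableSceneDataset
    label_to_category_id={"obj_000001": 1},  # illustrative mapping
    min_area=50,
)
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=4,
    num_workers=0,  # with an IterableDataset, each worker sees the full stream
    collate_fn=collate_fn,
)
rgbs, targets = next(iter(loader))  # each target holds boxes/labels/masks/area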
@@ -10,6 +10,13 @@
 import yaml
 
 from happypose.toolbox.datasets.datasets_cfg import make_urdf_dataset
+from happypose.toolbox.datasets.scene_dataset import (
+    CameraData,
+    ObjectData,
+    ObservationInfos,
+    SceneObservation,
+)
+from happypose.toolbox.lib3d.transform import Transform
 
 from .utils import make_detections_from_segmentation
 
@@ -68,12 +75,38 @@ def __getitem__(self, idx):
             0
         ]
         mask_uniqs = set(np.unique(mask[mask > 0]))
+        object_datas = []
         for obj in objects:
             if obj["id_in_segm"] in mask_uniqs:
                 obj["bbox"] = dets_gt[obj["id_in_segm"]].numpy()
-        state = {
-            "camera": cam,
-            "objects": objects,
-            "frame_info": self.frame_index.iloc[idx].to_dict(),
-        }
-        return rgb, mask, state
+                obj_data = ObjectData(
+                    label=obj["name"],
+                    TWO=Transform(obj["TWO"]),
+                    unique_id=obj["body_id"],
+                    bbox_modal=obj["bbox"],
+                )
+            else:
+                obj_data = ObjectData(
+                    label=obj["name"],
+                    TWO=Transform(obj["TWO"]),
+                    unique_id=obj["body_id"],
+                )
+            object_datas.append(obj_data)
+
+        image_infos = ObservationInfos(
+            scene_id=self.frame_index.iloc[idx].to_dict()["scene_id"],
+            view_id=self.frame_index.iloc[idx].to_dict()["view_id"],
+        )
+
+        cam = CameraData(
+            K=cam["K"], resolution=cam["resolution"], TWC=Transform(cam["TWC"])
+        )
+
+        observation = SceneObservation(
+            rgb=rgb.numpy().astype(np.uint8),
+            segmentation=mask.numpy().astype(np.uint32),
+            camera_data=cam,
+            infos=image_infos,
+            object_datas=object_datas,
+        )
+        return observation
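
With this change __getitem__ returns a typed SceneObservation instead of the old (rgb, mask, state) tuple. A hedged sketch of what a consumer sees, assuming dataset is an instance of this synthetic scene dataset (field names come from the constructor call above; shapes are assumptions):

obs = dataset[0]             # SceneObservation, not (rgb, mask, state)
rgb = obs.rgb                # (H, W, 3) uint8 image
segm = obs.segmentation      # (H, W) uint32 map of object unique_ids
K = obs.camera_data.K        # intrinsics; obs.camera_data.TWC is a Transform
for obj in obs.object_datas: # typed ObjectData entries
    # bbox_modal is only set when the object appears in the segmentation
    print(obj.label, obj.unique_id, obj.bbox_modal)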
@@ -10,7 +10,7 @@ def process_data(self, data):
         ids_visible = set(ids_visible[ids_visible > 0])
         visib_objects = []
         for obj in state["objects"]:
-            if obj["id_in_segm"] in ids_visible:
+            if obj.unique_id in ids_visible:
                 visib_objects.append(obj)
         state["objects"] = visib_objects
         return rgb, mask, state
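
The visibility filter now reads the typed ObjectData.unique_id field instead of the old obj["id_in_segm"] dict entry. The same logic in isolation, as a sketch with hypothetical inputs (objects stands in for a list of ObjectData):

import numpy as np

segm = np.array([[0, 3], [7, 3]], dtype=np.uint32)  # rendered unique_ids
ids_visible = set(np.unique(segm[segm > 0]))        # {3, 7}
visib_objects = [obj for obj in objects if obj.unique_id in ids_visible]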
@@ -94,18 +94,18 @@ def run_inference_pipeline(
         """
-        if self.inference_cfg.detection_type == "gt":
+        if self.inference_cfg["detection_type"] == "gt":
             detections = gt_detections
             run_detector = False
-        elif self.inference_cfg.detection_type == "detector":
+        elif self.inference_cfg["detection_type"] == "detector":
             detections = None
             run_detector = True
         else:
-            msg = f"Unknown detection type {self.inference_cfg.detection_type}"
+            msg = f"Unknown detection type {self.inference_cfg['detection_type']}"
             raise ValueError(msg)
 
         coarse_estimates = None
-        if self.inference_cfg.coarse_estimation_type == "external":
+        if self.inference_cfg["coarse_estimation_type"] == "external":
             # TODO (ylabbe): This is hacky, clean this for modelnet eval.
             coarse_estimates = initial_estimates
             coarse_estimates = happypose.toolbox.inference.utils.add_instance_id(
@@ -136,15 +136,15 @@
         all_preds = {}
         data_TCO_refiner = extra_data["refiner"]["preds"]
 
-        k_0 = f"refiner/iteration={self.inference_cfg.n_refiner_iterations}"
+        k_0 = f"refiner/iteration={self.inference_cfg['n_refiner_iterations']}"
         all_preds = {
             "final": preds,
             k_0: data_TCO_refiner,
             "refiner/final": data_TCO_refiner,
             "coarse": extra_data["coarse"]["preds"],
         }
 
-        if self.inference_cfg.run_depth_refiner:
+        if self.inference_cfg["run_depth_refiner"]:
             all_preds["depth_refiner"] = extra_data["depth_refiner"]["preds"]
 
         # Remove any mask tensors
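
These hunks swap attribute access (self.inference_cfg.detection_type) for mapping access (self.inference_cfg["detection_type"]), so inference_cfg can be a plain dict as well as an attribute-style config object. The pattern in isolation, with illustrative values:

cfg = {
    "detection_type": "gt",
    "coarse_estimation_type": "external",
    "n_refiner_iterations": 5,
    "run_depth_refiner": False,
}
if cfg["detection_type"] == "gt":
    run_detector = False
elif cfg["detection_type"] == "detector":
    run_detector = True
else:
    msg = f"Unknown detection type {cfg['detection_type']}"
    raise ValueError(msg)
k_0 = f"refiner/iteration={cfg['n_refiner_iterations']}"  # "refiner/iteration=5"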
@@ -173,43 +173,46 @@ def get_predictions(
         """
         predictions_list = defaultdict(list)
         for n, data in enumerate(tqdm(self.dataloader)):
-            # data is a dict
-            rgb = data["rgb"]
-            depth = None
-            K = data["cameras"].K
-            gt_detections = data["gt_detections"].cuda()
-
-            initial_data = None
-            if data["initial_data"]:
-                initial_data = data["initial_data"].cuda()
-
-            obs_tensor = ObservationTensor.from_torch_batched(rgb, depth, K)
-            obs_tensor = obs_tensor.cuda()
-
-            # GPU warmup for timing
-            if n == 0:
-                with torch.no_grad():
-                    self.run_inference_pipeline(
-                        pose_estimator,
-                        obs_tensor,
-                        gt_detections,
-                        initial_estimates=initial_data,
-                    )
-
-            cuda_timer = CudaTimer()
-            cuda_timer.start()
-            with torch.no_grad():
-                all_preds = self.run_inference_pipeline(
-                    pose_estimator,
-                    obs_tensor,
-                    gt_detections,
-                    initial_estimates=initial_data,
-                )
-            cuda_timer.end()
-            cuda_timer.elapsed()
-
-            for k, v in all_preds.items():
-                predictions_list[k].append(v)
+            if n < 3:
+                # data is a dict
+                rgb = data["rgb"]
+                depth = None
+                K = data["cameras"].K
+                gt_detections = data["gt_detections"].cuda()
+
+                initial_data = None
+                if data["initial_data"]:
+                    initial_data = data["initial_data"].cuda()
+
+                obs_tensor = ObservationTensor.from_torch_batched(rgb, depth, K)
+                obs_tensor = obs_tensor.cuda()
+
+                # GPU warmup for timing
+                if n == 0:
+                    with torch.no_grad():
+                        self.run_inference_pipeline(
+                            pose_estimator,
+                            obs_tensor,
+                            gt_detections,
+                            initial_estimates=initial_data,
+                        )
+
+                cuda_timer = CudaTimer()
+                cuda_timer.start()
+                with torch.no_grad():
+                    all_preds = self.run_inference_pipeline(
+                        pose_estimator,
+                        obs_tensor,
+                        gt_detections,
+                        initial_estimates=initial_data,
+                    )
+                cuda_timer.end()
+                cuda_timer.elapsed()
+
+                for k, v in all_preds.items():
+                    predictions_list[k].append(v)
+            else:
+                break
 
         # Concatenate the lists of PandasTensorCollections
         predictions = {}
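
get_predictions now times only the first three batches (if n < 3: ... else: break), with batch 0 doubling as GPU warmup before the timed pass. If CudaTimer is not at hand, the same warmup-then-time pattern can be sketched with raw CUDA events; this is an assumption about what CudaTimer wraps, not the happypose API (model and inputs are placeholders):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():       # warmup: first call pays CUDA init / autotune costs
    model(inputs)

start.record()
with torch.no_grad():
    preds = model(inputs)
end.record()
torch.cuda.synchronize()    # events are asynchronous; sync before reading
elapsed_ms = start.elapsed_time(end)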