Training cosypose #149

Merged · 22 commits · Mar 26, 2024

Commits (22):
31be268  debugging cosypose training (Jan 24, 2024)
d42fcca  1st step of fixing training + debug (ElliotMaitre, Jan 29, 2024)
05c0e2e  first working version of pose training (Jan 29, 2024)
7dcd621  working towards detection training (ElliotMaitre, Feb 16, 2024)
682604e  adaptation of detector_training from cosypose to happypose (ElliotMaitre, Mar 4, 2024)
dbeda87  merging from dev (ElliotMaitre, Mar 4, 2024)
5376d72  Adding test for GPU. Changing test order to make it work (ElliotMaitre, Mar 21, 2024)
1f0259f  Merge branch 'dev' of github.com:agimus-project/happypose into traini… (ElliotMaitre, Mar 21, 2024)
3b6259d  adding pytest to poetry (ElliotMaitre, Mar 25, 2024)
400d14d  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 25, 2024)
5c00e07  removing debug code (ElliotMaitre, Mar 25, 2024)
02b71cd  merging? (ElliotMaitre, Mar 25, 2024)
8b537d7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 25, 2024)
02b412f  removing debug prints (ElliotMaitre, Mar 25, 2024)
75451cc  Merge branch 'training_cosypose' of github.com:agimus-project/happypo… (ElliotMaitre, Mar 25, 2024)
f5829a0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 25, 2024)
4bf9b57  fix bugs highlited in pre-commit ci (ElliotMaitre, Mar 25, 2024)
91ee620  Merge branch 'training_cosypose' of github.com:agimus-project/happypo… (ElliotMaitre, Mar 25, 2024)
16b1128  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 25, 2024)
91e278a  removing unused data types (ElliotMaitre, Mar 25, 2024)
5a1801d  Merge branch 'training_cosypose' of github.com:agimus-project/happypo… (ElliotMaitre, Mar 25, 2024)
b1f47cf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Mar 25, 2024)
@@ -143,7 +143,6 @@ def __call__(self, im, mask, obs):

 class VOCBackgroundAugmentation(BackgroundAugmentation):
     def __init__(self, voc_root, p=0.3):
-        print("voc_root =", voc_root)
         image_dataset = VOCSegmentation(
             root=voc_root,
             year="2012",
@@ -4,21 +4,30 @@
 import torch

 from happypose.pose_estimators.cosypose.cosypose.config import LOCAL_DATA_DIR

-from .augmentations import (
-    CropResizeToAspectAugmentation,
+from happypose.toolbox.datasets.augmentations import (
+    CropResizeToAspectTransform,
     PillowBlur,
     PillowBrightness,
     PillowColor,
     PillowContrast,
     PillowSharpness,
     VOCBackgroundAugmentation,
-    to_torch_uint8,
 )
-from .wrappers.visibility_wrapper import VisibilityWrapper
+from happypose.toolbox.datasets.augmentations import (
+    SceneObservationAugmentation as SceneObsAug,
+)
+
+# HappyPose
+from happypose.toolbox.datasets.scene_dataset import (
+    IterableSceneDataset,
+    SceneDataset,
+    SceneObservation,
+)
+from happypose.toolbox.datasets.scene_dataset_wrappers import remove_invisible_objects


-class DetectionDataset(torch.utils.data.Dataset):
+class DetectionDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         scene_ds,
@@ -29,59 +38,76 @@ def __init__(
         rgb_augmentation=False,
         background_augmentation=False,
     ):
-        self.scene_ds = VisibilityWrapper(scene_ds)
-
-        self.resize_augmentation = CropResizeToAspectAugmentation(resize=resize)
-
-        self.background_augmentation = background_augmentation
-        self.background_augmentations = VOCBackgroundAugmentation(
-            voc_root=LOCAL_DATA_DIR / "VOCdevkit/VOC2012",
-            p=0.3,
-        )
+        self.scene_ds = scene_ds
+
+        self.resize_augmentation = CropResizeToAspectTransform()
+
+        self.background_augmentations = []
+        self.background_augmentations += [
+            (
+                SceneObsAug(
+                    VOCBackgroundAugmentation(LOCAL_DATA_DIR),
+                    p=0.3,
+                )
+            ),
+        ]

         self.rgb_augmentation = rgb_augmentation
-        self.rgb_augmentations = [
-            PillowBlur(p=0.4, factor_interval=(1, 3)),
-            PillowSharpness(p=0.3, factor_interval=(0.0, 50.0)),
-            PillowContrast(p=0.3, factor_interval=(0.2, 50.0)),
-            PillowBrightness(p=0.5, factor_interval=(0.1, 6.0)),
-            PillowColor(p=0.3, factor_interval=(0.0, 20.0)),
-        ]
+        self.rgb_augmentations = []
+        self.rgb_augmentations += [
+            SceneObsAug(
+                [
+                    SceneObsAug(PillowBlur(factor_interval=(1, 3)), p=0.4),
+                    SceneObsAug(PillowSharpness(factor_interval=(0.0, 50.0)), p=0.3),
+                    SceneObsAug(PillowContrast(factor_interval=(0.2, 50.0)), p=0.3),
+                    SceneObsAug(PillowBrightness(factor_interval=(0.1, 6.0)), p=0.5),
+                    SceneObsAug(PillowColor(factor_interval=(0.0, 20.0)), p=0.3),
+                ]
+            )
+        ]

         self.label_to_category_id = label_to_category_id
         self.min_area = min_area

-    def __len__(self):
-        return len(self.scene_ds)
+    def make_data_from_obs(self, obs: SceneObservation, idx):
+        obs = remove_invisible_objects(obs)

-    def get_data(self, idx):
-        rgb, mask, state = self.scene_ds[idx]
+        obs = self.resize_augmentation(obs)

-        rgb, mask, state = self.resize_augmentation(rgb, mask, state)
+        for aug in self.background_augmentations:
+            obs = aug(obs)

-        if self.background_augmentation:
-            rgb, mask, state = self.background_augmentations(rgb, mask, state)
-
-        if self.rgb_augmentation and random.random() < 0.8:
-            for augmentation in self.rgb_augmentations:
-                rgb, mask, state = augmentation(rgb, mask, state)
-
-        rgb, mask = to_torch_uint8(rgb), to_torch_uint8(mask)
+        if self.rgb_augmentations and random.random() < 0.8:
+            for aug in self.rgb_augmentations:
+                obs = aug(obs)

+        assert obs.object_datas is not None
+        assert obs.rgb is not None
+        assert obs.camera_data is not None
         categories = torch.tensor(
-            [self.label_to_category_id[obj["name"]] for obj in state["objects"]],
+            [self.label_to_category_id[obj.label] for obj in obs.object_datas],
         )
-        obj_ids = np.array([obj["id_in_segm"] for obj in state["objects"]])
         boxes = np.array(
-            [torch.as_tensor(obj["bbox"]).tolist() for obj in state["objects"]],
+            [torch.as_tensor(obj.bbox_modal).tolist() for obj in obs.object_datas],
         )
         boxes = torch.as_tensor(boxes, dtype=torch.float32).view(-1, 4)
         area = torch.as_tensor(
             (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
         )
-        mask = np.array(mask)
-        masks = mask == obj_ids[:, None, None]
-        masks = torch.as_tensor(masks)
+        obj_ids = np.array([obj.unique_id for obj in obs.object_datas])
+
+        masks = []
+        for _n, obj_data in enumerate(obs.object_datas):
+            if obs.binary_masks is not None:
+                binary_mask = torch.tensor(obs.binary_masks[obj_data.unique_id]).float()
+                masks.append(binary_mask)
+
+            if obs.segmentation is not None:
+                binary_mask = np.zeros_like(obs.segmentation, dtype=np.bool_)
+                binary_mask[obs.segmentation == obj_data.unique_id] = 1
+                binary_mask = torch.as_tensor(binary_mask).float()
+                masks.append(binary_mask)
+
+        masks = np.array(masks)
+        masks = masks == obj_ids[:, None, None]

         keep = area > self.min_area
         boxes = boxes[keep]
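A note on the mask arithmetic above: both the old and the new code rely on NumPy broadcasting, comparing an array of masks against `obj_ids[:, None, None]` to get one boolean mask per object id. A minimal sketch of that trick, with an invented 3x3 segmentation map:

import numpy as np

# Toy segmentation map; pixel value = object id (0 = background).
segmentation = np.array(
    [[0, 1, 1],
     [2, 2, 0],
     [0, 0, 2]],
    dtype=np.uint32,
)
obj_ids = np.array([1, 2])

# Broadcasting: (1, H, W) compared against (N, 1, 1) yields (N, H, W),
# i.e. one boolean mask per object id.
masks = segmentation[None, :, :] == obj_ids[:, None, None]
assert masks.shape == (2, 3, 3)
assert masks.sum(axis=(1, 2)).tolist() == [2, 3]  # pixels per object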
Expand All @@ -97,26 +123,36 @@ def get_data(self, idx):
image_id = torch.tensor([idx])
iscrowd = torch.zeros((num_objs), dtype=torch.int64)

rgb = torch.as_tensor(obs.rgb)
target = {}
target["boxes"] = boxes
target["labels"] = categories
target["masks"] = masks
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

return rgb, target

def __getitem__(self, index):
try_index = index
valid = False
def __getitem__(self, index: int):
assert isinstance(self.scene_ds, SceneDataset)
obs = self.scene_ds[index]
return self.make_data_from_obs(obs, index)

# def find_valid_data(self, iterator: Iterator[SceneObservation]) -> PoseData:
def find_valid_data(self, iterator):
n_attempts = 0
while not valid:
if n_attempts > 10:
for idx, obs in enumerate(iterator):
data = self.make_data_from_obs(obs, idx)
if data is not None:
return data
n_attempts += 1
if n_attempts > 200:
msg = "Cannot find valid image in the dataset"
raise ValueError(msg)
im, target = self.get_data(try_index)
valid = len(target["boxes"]) > 0
if not valid:
try_index = random.randint(0, len(self.scene_ds) - 1)
n_attempts += 1
return im, target

def __iter__(self):
assert isinstance(self.scene_ds, IterableSceneDataset)
iterator = iter(self.scene_ds)
while True:
yield self.find_valid_data(iterator)
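The structural change in this file: `DetectionDataset` moves from a map-style `torch.utils.data.Dataset` (index plus retry on invalid samples) to an `IterableDataset` that pulls from an underlying scene iterator and skips invalid samples on the fly. A self-contained sketch of the same pattern, with invented names (`FilteredIterable`, `stream`), showing how a `DataLoader` would consume it:

import torch
from torch.utils.data import DataLoader, IterableDataset


class FilteredIterable(IterableDataset):
    """Toy version of the pattern above: pull samples from an underlying
    iterator, skip invalid ones, give up after too many failures."""

    def __init__(self, source, max_attempts=200):
        self.source = source
        self.max_attempts = max_attempts

    def find_valid_data(self, iterator):
        n_attempts = 0
        for sample in iterator:
            if sample is not None:  # stands in for make_data_from_obs succeeding
                return sample
            n_attempts += 1
            if n_attempts > self.max_attempts:
                raise ValueError("Cannot find valid image in the dataset")

    def __iter__(self):
        iterator = iter(self.source)
        while True:
            yield self.find_valid_data(iterator)


def stream():
    # Infinite sample stream with occasional invalid (None) entries.
    i = 0
    while True:
        yield None if i % 3 == 0 else torch.tensor([i])
        i += 1


loader = DataLoader(FilteredIterable(stream()), batch_size=4)
print(next(iter(loader)))  # tensor([[1], [2], [4], [5]])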
@@ -10,6 +10,13 @@
 import yaml

 from happypose.toolbox.datasets.datasets_cfg import make_urdf_dataset
+from happypose.toolbox.datasets.scene_dataset import (
+    CameraData,
+    ObjectData,
+    ObservationInfos,
+    SceneObservation,
+)
+from happypose.toolbox.lib3d.transform import Transform

 from .utils import make_detections_from_segmentation

@@ -68,12 +75,38 @@ def __getitem__(self, idx):
                 0
             ]
         mask_uniqs = set(np.unique(mask[mask > 0]))
+        object_datas = []
         for obj in objects:
             if obj["id_in_segm"] in mask_uniqs:
                 obj["bbox"] = dets_gt[obj["id_in_segm"]].numpy()
-        state = {
-            "camera": cam,
-            "objects": objects,
-            "frame_info": self.frame_index.iloc[idx].to_dict(),
-        }
-        return rgb, mask, state
+                obj_data = ObjectData(
+                    label=obj["name"],
+                    TWO=Transform(obj["TWO"]),
+                    unique_id=obj["body_id"],
+                    bbox_modal=obj["bbox"],
+                )
+            else:
+                obj_data = ObjectData(
+                    label=obj["name"],
+                    TWO=Transform(obj["TWO"]),
+                    unique_id=obj["body_id"],
+                )
+            object_datas.append(obj_data)
+
+        image_infos = ObservationInfos(
+            scene_id=self.frame_index.iloc[idx].to_dict()["scene_id"],
+            view_id=self.frame_index.iloc[idx].to_dict()["view_id"],
+        )
+
+        cam = CameraData(
+            K=cam["K"], resolution=cam["resolution"], TWC=Transform(cam["TWC"])
+        )
+
+        observation = SceneObservation(
+            rgb=rgb.numpy().astype(np.uint8),
+            segmentation=mask.numpy().astype(np.uint32),
+            camera_data=cam,
+            infos=image_infos,
+            object_datas=object_datas,
+        )
+        return observation
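This file's `__getitem__` now returns a typed `SceneObservation` instead of a raw `(rgb, mask, state)` tuple. Below is a hedged sketch of building one by hand; the field names come from the diff above, while the shapes, the identity poses, and the assumption that `Transform` accepts a 4x4 homogeneous matrix are illustrative only:

import numpy as np

from happypose.toolbox.datasets.scene_dataset import (
    CameraData,
    ObjectData,
    ObservationInfos,
    SceneObservation,
)
from happypose.toolbox.lib3d.transform import Transform

rgb = np.zeros((480, 640, 3), dtype=np.uint8)
segmentation = np.zeros((480, 640), dtype=np.uint32)
segmentation[100:200, 150:300] = 1  # pixels belonging to unique_id == 1

obj = ObjectData(
    label="obj_000001",
    TWO=Transform(np.eye(4)),  # assumption: a 4x4 homogeneous matrix is accepted
    unique_id=1,
    bbox_modal=np.array([150, 100, 300, 200]),  # xmin, ymin, xmax, ymax
)
cam = CameraData(
    K=np.array([[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]]),
    resolution=(480, 640),  # assumption: (height, width) ordering
    TWC=Transform(np.eye(4)),
)
obs = SceneObservation(
    rgb=rgb,
    segmentation=segmentation,
    camera_data=cam,
    infos=ObservationInfos(scene_id="0", view_id="0"),
    object_datas=[obj],
)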
@@ -10,7 +10,7 @@ def process_data(self, data):
         ids_visible = set(ids_visible[ids_visible > 0])
         visib_objects = []
         for obj in state["objects"]:
-            if obj["id_in_segm"] in ids_visible:
+            if obj.unique_id in ids_visible:
                 visib_objects.append(obj)
         state["objects"] = visib_objects
         return rgb, mask, state
@@ -94,18 +94,18 @@ def run_inference_pipeline(


         """
-        if self.inference_cfg.detection_type == "gt":
+        if self.inference_cfg["detection_type"] == "gt":
             detections = gt_detections
             run_detector = False
-        elif self.inference_cfg.detection_type == "detector":
+        elif self.inference_cfg["detection_type"] == "detector":
             detections = None
             run_detector = True
         else:
-            msg = f"Unknown detection type {self.inference_cfg.detection_type}"
+            msg = f"Unknown detection type {self.inference_cfg['detection_type']}"
             raise ValueError(msg)

         coarse_estimates = None
-        if self.inference_cfg.coarse_estimation_type == "external":
+        if self.inference_cfg["coarse_estimation_type"] == "external":
             # TODO (ylabbe): This is hacky, clean this for modelnet eval.
             coarse_estimates = initial_estimates
             coarse_estimates = happypose.toolbox.inference.utils.add_instance_id(
@@ -136,15 +136,15 @@ def run_inference_pipeline(
         all_preds = {}
         data_TCO_refiner = extra_data["refiner"]["preds"]

-        k_0 = f"refiner/iteration={self.inference_cfg.n_refiner_iterations}"
+        k_0 = f"refiner/iteration={self.inference_cfg['n_refiner_iterations']}"
         all_preds = {
             "final": preds,
             k_0: data_TCO_refiner,
             "refiner/final": data_TCO_refiner,
             "coarse": extra_data["coarse"]["preds"],
         }

-        if self.inference_cfg.run_depth_refiner:
+        if self.inference_cfg["run_depth_refiner"]:
             all_preds["depth_refiner"] = extra_data["depth_refiner"]["preds"]

         # Remove any mask tensors
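The pattern in the two hunks above is uniform: every `self.inference_cfg.some_key` attribute access becomes `self.inference_cfg["some_key"]`, which suggests the config is now carried as a plain dict (an inference from the diff, not stated in it). The difference in a toy snippet:

from types import SimpleNamespace

cfg_obj = SimpleNamespace(detection_type="detector", n_refiner_iterations=5)
cfg_dict = {"detection_type": "detector", "n_refiner_iterations": 5}

# Attribute access works on the namespace-style config...
assert cfg_obj.detection_type == cfg_dict["detection_type"]

# ...but raises AttributeError on a plain dict, hence the bracket syntax.
try:
    cfg_dict.detection_type  # type: ignore[attr-defined]
except AttributeError:
    pass

k_0 = f"refiner/iteration={cfg_dict['n_refiner_iterations']}"
assert k_0 == "refiner/iteration=5"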
@@ -173,43 +173,46 @@ def get_predictions(
         """
         predictions_list = defaultdict(list)
         for n, data in enumerate(tqdm(self.dataloader)):
-            # data is a dict
-            rgb = data["rgb"]
-            depth = None
-            K = data["cameras"].K
-            gt_detections = data["gt_detections"].cuda()
-
-            initial_data = None
-            if data["initial_data"]:
-                initial_data = data["initial_data"].cuda()
-
-            obs_tensor = ObservationTensor.from_torch_batched(rgb, depth, K)
-            obs_tensor = obs_tensor.cuda()
-
-            # GPU warmup for timing
-            if n == 0:
-                with torch.no_grad():
-                    self.run_inference_pipeline(
-                        pose_estimator,
-                        obs_tensor,
-                        gt_detections,
-                        initial_estimates=initial_data,
-                    )
-
-            cuda_timer = CudaTimer()
-            cuda_timer.start()
-            with torch.no_grad():
-                all_preds = self.run_inference_pipeline(
-                    pose_estimator,
-                    obs_tensor,
-                    gt_detections,
-                    initial_estimates=initial_data,
-                )
-            cuda_timer.end()
-            cuda_timer.elapsed()
-
-            for k, v in all_preds.items():
-                predictions_list[k].append(v)
+            if n < 3:
+                # data is a dict
+                rgb = data["rgb"]
+                depth = None
+                K = data["cameras"].K
+                gt_detections = data["gt_detections"].cuda()
+
+                initial_data = None
+                if data["initial_data"]:
+                    initial_data = data["initial_data"].cuda()
+
+                obs_tensor = ObservationTensor.from_torch_batched(rgb, depth, K)
+                obs_tensor = obs_tensor.cuda()
+
+                # GPU warmup for timing
+                if n == 0:
+                    with torch.no_grad():
+                        self.run_inference_pipeline(
+                            pose_estimator,
+                            obs_tensor,
+                            gt_detections,
+                            initial_estimates=initial_data,
+                        )
+
+                cuda_timer = CudaTimer()
+                cuda_timer.start()
+                with torch.no_grad():
+                    all_preds = self.run_inference_pipeline(
+                        pose_estimator,
+                        obs_tensor,
+                        gt_detections,
+                        initial_estimates=initial_data,
+                    )
+                cuda_timer.end()
+                cuda_timer.elapsed()
+
+                for k, v in all_preds.items():
+                    predictions_list[k].append(v)
+            else:
+                break

         # Concatenate the lists of PandasTensorCollections
         predictions = {}
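Two things happen in this last hunk: the loop body is wrapped in `if n < 3: ... else: break`, so only the first three batches are processed (presumably for the GPU test mentioned in commit 5376d72), and batch 0 is still run once untimed as a GPU warmup before the timed pass. Here is that warmup-then-time pattern in isolation, using `torch.cuda.Event` directly as a stand-in for happypose's `CudaTimer`, which presumably wraps the same mechanism:

import torch


def timed_forward(model, x, warmup=True):
    if warmup:
        with torch.no_grad():
            model(x)  # first CUDA call pays kernel/cudnn setup costs
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with torch.no_grad():
        out = model(x)
    end.record()
    torch.cuda.synchronize()  # events are asynchronous; sync before reading
    return out, start.elapsed_time(end)  # elapsed time in milliseconds


if torch.cuda.is_available():
    model = torch.nn.Linear(64, 64).cuda()
    x = torch.randn(8, 64, device="cuda")
    _, ms = timed_forward(model, x)
    print(f"forward pass: {ms:.3f} ms")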