diff --git a/icevision/core/record_components.py b/icevision/core/record_components.py index 1fff4c670..873e5d7ca 100644 --- a/icevision/core/record_components.py +++ b/icevision/core/record_components.py @@ -274,6 +274,7 @@ def _autofix(self) -> Dict[str, bool]: def _remove_annotation(self, i): self.label_ids.pop(i) + self.labels.pop(i) def _aggregate_objects(self) -> Dict[str, List[dict]]: return {**super()._aggregate_objects(), "labels": self.label_ids} diff --git a/icevision/data/dataset.py b/icevision/data/dataset.py index 260b8f9b2..6f67600e1 100644 --- a/icevision/data/dataset.py +++ b/icevision/data/dataset.py @@ -32,13 +32,16 @@ def __len__(self): return len(self.records) def __getitem__(self, i): - record = self.records[i].load() - if self.tfm is not None: - record = self.tfm(record) + if isinstance(i, slice): + return self.__class__(self.records[i], self.tfm) else: - # HACK FIXME - record.set_img(np.array(record.img)) - return record + record = self.records[i].load() + if self.tfm is not None: + record = self.tfm(record) + else: + # HACK FIXME + record.set_img(np.array(record.img)) + return record def __repr__(self): return f"<{self.__class__.__name__} with {len(self.records)} items>" diff --git a/setup.cfg b/setup.cfg index 7ee68bac5..28c47d023 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,7 +48,7 @@ inference = all = fastai >=2.5.2,<2.6 - pytorch-lightning >=1.4.5 + pytorch-lightning >=1.4.5,<1.7.0 wandb >=0.10.7 dev =