diff --git a/src/mgds/pipelineModules/DiskCache.py b/src/mgds/pipelineModules/DiskCache.py index 9eeaf89..df54daa 100644 --- a/src/mgds/pipelineModules/DiskCache.py +++ b/src/mgds/pipelineModules/DiskCache.py @@ -220,7 +220,7 @@ def fn(group_index, in_index, in_variation): self.aggregate_cache[group_key][in_variation] = \ torch.load(os.path.realpath(os.path.join(cache_dir, 'aggregate.pt')), weights_only=False) - def __get_input_index(self, out_variation: int, out_index: int) -> (str, int, int): + def __get_input_index(self, out_variation: int, out_index: int) -> tuple[str, int, int, int]: offset = 0 for group_key, group_output_samples in self.group_output_samples.items(): if out_index >= group_output_samples + offset: