From 06a9a03b637c848e9c4d49e92444d523370a0c2f Mon Sep 17 00:00:00 2001
From: Little-Podi
Date: Sun, 6 Oct 2024 10:34:17 +0800
Subject: [PATCH] eliminate sdata issue

---
 docs/INSTALL.md      |  1 -
 docs/ISSUES.md       | 16 +--------
 vwm/data/__init__.py |  1 -
 vwm/data/dataset.py  | 77 --------------------------------------------
 4 files changed, 1 insertion(+), 94 deletions(-)
 delete mode 100644 vwm/data/__init__.py

diff --git a/docs/INSTALL.md b/docs/INSTALL.md
index 6059d42..8e92aa9 100644
--- a/docs/INSTALL.md
+++ b/docs/INSTALL.md
@@ -40,7 +40,6 @@
    ```shell
    conda install -y pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
    pip3 install -r requirements.txt
-   pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
    ```
 
 ---
diff --git a/docs/ISSUES.md b/docs/ISSUES.md
index f8223ff..91388ba 100644
--- a/docs/ISSUES.md
+++ b/docs/ISSUES.md
@@ -17,21 +17,7 @@
 1. Download [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14/tree/main) and [laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/tree/main) in advance.
 2. Set *version* of FrozenCLIPEmbedder and FrozenOpenCLIPImageEmbedder in `vwm/modules/encoders/modules.py` to the new paths of `pytorch_model.bin`/`open_clip_pytorch_model.bin`.
-3. #### Datasets not yet available during training.
-
-   - Possible reason:
-
-     - The installed [sdata](https://github.com/Stability-AI/datapipelines) is not detected.
-
-   - Try this:
-
-     - Reinstall in the current project directory.
-
-       ````shell
-       pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
-       ````
-
-4. #### The shapes of linear layers cannot be multiplied at the cross-attention layers.
+3. #### The shapes of linear layers cannot be multiplied at the cross-attention layers.
 
    - Possible reason:
 
     - The dimension of cross-attention is not expanded while the action conditions are injected, resulting in a mismatch.
diff --git a/vwm/data/__init__.py b/vwm/data/__init__.py
deleted file mode 100644
index 7664a25..0000000
--- a/vwm/data/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .dataset import StableDataModuleFromConfig
diff --git a/vwm/data/dataset.py b/vwm/data/dataset.py
index 44135c4..e260390 100644
--- a/vwm/data/dataset.py
+++ b/vwm/data/dataset.py
@@ -1,87 +1,10 @@
 import random
-from typing import Optional
 
-import torchdata.datapipes.iter
-import webdataset as wds
-from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset
 
 from .subsets import YouTubeDataset, NuScenesDataset
 
-try:
-    from sdata import create_dataset, create_dummy_dataset, create_loader
-except ImportError as e:
-    print("#" * 100)
-    print("Datasets not yet available")
-    print("To enable, we need to add stable-datasets as a submodule")
-    print("Please use ``git submodule update --init --recursive``")
-    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
-    print("#" * 100)
-    exit(1)
-
-
-class StableDataModuleFromConfig(LightningDataModule):
-    def __init__(
-        self,
-        train: DictConfig,
-        validation: Optional[DictConfig] = None,
-        test: Optional[DictConfig] = None,
-        skip_val_loader: bool = False,
-        dummy: bool = False
-    ):
-        super().__init__()
-        self.train_config = train
-        assert (
-            "datapipeline" in self.train_config and "loader" in self.train_config
-        ), "Train config requires the fields `datapipeline` and `loader`"
-
-        self.val_config = validation
-        if not skip_val_loader:
-            if self.val_config is not None:
-                assert (
-                    "datapipeline" in self.val_config and "loader" in self.val_config
-                ), "Validation config requires the fields `datapipeline` and `loader`"
-            else:
-                print(
-                    "WARNING: no validation datapipeline defined, using that one from training"
-                )
-                self.val_config = train
-
-        self.test_config = test
-        if self.test_config is not None:
-            assert (
-                "datapipeline" in self.test_config and "loader" in self.test_config
-            ), "Test config requires the fields `datapipeline` and `loader`"
-
-        self.dummy = dummy
-        if self.dummy:
-            print("#" * 100)
-            print("Using dummy dataset, hope you are debugging")
-            print("#" * 100)
-
-    def setup(self, stage: str) -> None:
-        print("Preparing datasets")
-        if self.dummy:
-            data_fn = create_dummy_dataset
-        else:
-            data_fn = create_dataset
-
-        self.train_data_pipeline = data_fn(**self.train_config.datapipeline)
-        if self.val_config:
-            self.val_data_pipeline = data_fn(**self.val_config.datapipeline)
-        if self.test_config:
-            self.test_data_pipeline = data_fn(**self.test_config.datapipeline)
-
-    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
-        return create_loader(self.train_data_pipeline, **self.train_config.loader)
-
-    def val_dataloader(self) -> wds.DataPipeline:
-        return create_loader(self.val_data_pipeline, **self.val_config.loader)
-
-    def test_dataloader(self) -> wds.DataPipeline:
-        return create_loader(self.test_data_pipeline, **self.test_config.loader)
-
 
 def dataset_mapping(subset_list: list, target_height: int, target_width: int, num_frames: int):
    datasets = list()
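With sdata gone, `vwm/data/dataset.py` keeps only the subset datasets (`YouTubeDataset`, `NuScenesDataset`) and the `dataset_mapping` helper, so data presumably now flows through ordinary map-style `Dataset` objects served by plain `DataLoader`s rather than sdata pipelines. The following is a minimal sketch of what a replacement data module along those lines could look like; the class name `VideoDataModule`, its constructor signature, and the `ConcatDataset` merge are illustrative assumptions, not code from this patch.

```python
# Minimal sketch of an sdata-free data module. Assumes the subset datasets
# are plain map-style torch Datasets; the class name and constructor
# signature are hypothetical, not taken from this patch.
from typing import List, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class VideoDataModule(LightningDataModule):
    def __init__(
        self,
        train_datasets: List[Dataset],
        val_datasets: Optional[List[Dataset]] = None,
        batch_size: int = 1,
        num_workers: int = 4
    ):
        super().__init__()
        # Merge the per-subset datasets into one sampling pool
        self.train_data = ConcatDataset(train_datasets)
        # Fall back to the training data when no validation split is given,
        # mirroring the warning path of the removed StableDataModuleFromConfig
        self.val_data = ConcatDataset(val_datasets) if val_datasets else self.train_data
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            pin_memory=True
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )
```

If `dataset_mapping` returns the list it accumulates, a training script could then build the module as `VideoDataModule(dataset_mapping(subset_list, target_height, target_width, num_frames))` and hand it to a Lightning `Trainer`.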