Skip to content

Commit

Permalink
eliminate sdata issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Little-Podi committed Oct 6, 2024
1 parent 2150ab2 commit 06a9a03
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 94 deletions.
1 change: 0 additions & 1 deletion docs/INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
```shell
conda install -y pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip3 install -r requirements.txt
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
```
---
Expand Down
16 changes: 1 addition & 15 deletions docs/ISSUES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,7 @@
1. Download [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14/tree/main) and [laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/tree/main) in advance.
2. Set *version* of FrozenCLIPEmbedder and FrozenOpenCLIPImageEmbedder in `vwm/modules/encoders/modules.py` to the new paths of `pytorch_model.bin`/`open_clip_pytorch_model.bin`.

3. #### Datasets not yet available during training.

- Possible reason:

- The installed [sdata](https://github.com/Stability-AI/datapipelines) is not detected.

- Try this:

- Reinstall in the current project directory.

````shell
pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
````

4. #### The shapes of linear layers cannot be multiplied at the cross-attention layers.
3. #### The shapes of linear layers cannot be multiplied at the cross-attention layers.

- Possible reason:
- The dimension of cross-attention is not expanded while the action conditions are injected, resulting in a mismatch.
Expand Down
1 change: 0 additions & 1 deletion vwm/data/__init__.py

This file was deleted.

77 changes: 0 additions & 77 deletions vwm/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,87 +1,10 @@
import random
from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from .subsets import YouTubeDataset, NuScenesDataset

# Hard gate on the optional `sdata` data-pipeline package: the loader factories
# below (create_dataset / create_dummy_dataset / create_loader) all come from it,
# so without it this module cannot function at all.
try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
    # Print a prominent banner with install instructions rather than letting a
    # bare ImportError surface to the user.
    print("#" * 100)
    print("Datasets not yet available")
    print("To enable, we need to add stable-datasets as a submodule")
    print("Please use ``git submodule update --init --recursive``")
    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
    print("#" * 100)
    # NOTE(review): terminates the whole interpreter at import time — any caller
    # that merely imports this module dies here. `exit` is the `site`-injected
    # builtin; presumably intentional for CLI use, but confirm importers are OK
    # with a hard exit (a re-raised ImportError would be gentler).
    exit(1)


class StableDataModuleFromConfig(LightningDataModule):
    """Lightning data module built from OmegaConf sections.

    Each of ``train`` / ``validation`` / ``test`` is a ``DictConfig`` that must
    carry two fields: ``datapipeline`` (kwargs for the dataset factory) and
    ``loader`` (kwargs for ``create_loader``). When no validation section is
    given and ``skip_val_loader`` is False, the training section is reused for
    validation. With ``dummy=True`` the dummy dataset factory is used instead
    of the real one (debugging aid).
    """

    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False
    ):
        super().__init__()

        # Training section is mandatory and must be fully specified.
        self.train_config = train
        train_ok = "datapipeline" in train and "loader" in train
        assert train_ok, "Train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is None:
                # Fall back to the training pipeline for validation.
                print(
                    "WARNING: no validation datapipeline defined, using that one from training"
                )
                self.val_config = train
            else:
                val_ok = "datapipeline" in self.val_config and "loader" in self.val_config
                assert val_ok, "Validation config requires the fields `datapipeline` and `loader`"

        # Test section is optional; validate only when present.
        self.test_config = test
        if self.test_config is not None:
            test_ok = "datapipeline" in self.test_config and "loader" in self.test_config
            assert test_ok, "Test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            banner = "#" * 100
            print(banner)
            print("Using dummy dataset, hope you are debugging")
            print(banner)

    def setup(self, stage: str) -> None:
        """Instantiate the data pipelines for every configured split."""
        print("Preparing datasets")
        # Dummy factory swaps in trivially-constructed data for debugging runs.
        factory = create_dummy_dataset if self.dummy else create_dataset

        self.train_data_pipeline = factory(**self.train_config.datapipeline)
        if self.val_config:
            self.val_data_pipeline = factory(**self.val_config.datapipeline)
        if self.test_config:
            self.test_data_pipeline = factory(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        """Wrap the training pipeline in a loader per the `loader` config."""
        loader_kwargs = self.train_config.loader
        return create_loader(self.train_data_pipeline, **loader_kwargs)

    def val_dataloader(self) -> wds.DataPipeline:
        """Wrap the validation pipeline in a loader per the `loader` config."""
        loader_kwargs = self.val_config.loader
        return create_loader(self.val_data_pipeline, **loader_kwargs)

    def test_dataloader(self) -> wds.DataPipeline:
        """Wrap the test pipeline in a loader per the `loader` config."""
        loader_kwargs = self.test_config.loader
        return create_loader(self.test_data_pipeline, **loader_kwargs)


def dataset_mapping(subset_list: list, target_height: int, target_width: int, num_frames: int):
datasets = list()
Expand Down

0 comments on commit 06a9a03

Please sign in to comment.