Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacing weight with multiplier #105

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions ultravox/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class DataDictConfig(BaseModel):
splits: List[str] = dataclasses.field(default_factory=list)
num_samples: Optional[int] = None
total_samples: int = 1
# epochs mode:
# Weight is the number of copies of the dataset
# max steps mode:
# Weight of the dataset is used to calculate the proportion of the total samples that come from this dataset
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
weight: float = 1.0
streaming: bool = True
user_template: str = "<|audio|>"
Expand Down
30 changes: 18 additions & 12 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class SizedIterableDataset(abc.ABC, data.IterableDataset):
"""

@abc.abstractmethod
def __len__(self) -> int:
    """Return the number of samples in this dataset (may be an estimate)."""
    pass


Expand Down Expand Up @@ -373,7 +373,7 @@ def __iter__(self):
f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update."
)

def __len__(self) -> int:
    # Reports the estimate supplied at construction; the surrounding
    # __iter__ compares it against the actual count and flags a mismatch.
    return self._estimated_length

@abc.abstractmethod
Expand Down Expand Up @@ -493,7 +493,7 @@ def __init__(self, estimated_length: int = 1) -> None:
def __iter__(self):
    # Empty dataset: iteration yields nothing.
    return iter([])

def __len__(self) -> int:
    # Reports the estimated length (constructor default is 1) even though
    # iteration over this dataset is always empty.
    return self._estimated_length


Expand Down Expand Up @@ -1049,16 +1049,17 @@ def __init__(
if self._args.shuffle:
dataset = dataset.shuffle(seed=self._args.shuffle_seed)

if config.num_samples:
dataset = Range(dataset, config.num_samples, config.total_samples)

self._weight = config.weight

self.user_template = config.user_template
self.assistant_template = config.assistant_template
self.transcript_template = config.transcript_template

super()._init_dataset(dataset, config.total_samples)
if config.num_samples:
dataset = Range(dataset, config.num_samples, config.total_samples)
super()._init_dataset(dataset, len(dataset))
else:
super()._init_dataset(dataset, config.total_samples)

def _get_sample(self, row) -> VoiceSample:
try:
Expand Down Expand Up @@ -1129,6 +1130,7 @@ def __init__(
stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED,
seed: Optional[int] = 42,
static: bool = False,
using_epochs: bool = False,
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Args:
Expand All @@ -1142,8 +1144,12 @@ def __init__(
self._static = static

self._stop_strategy = stop_strategy
self._using_epochs = using_epochs
if not self._using_epochs:
weights = [int(getattr(ds, "weight", 1) * len(ds)) for ds in datasets]
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
else:
weights = [getattr(ds, "weight", 1) for ds in datasets]

weights = [getattr(ds, "weight", 1) for ds in datasets]
total_weight = sum(weights)
self._normalized_probs = [w / total_weight for w in weights]

Expand Down Expand Up @@ -1180,9 +1186,9 @@ def __iter__(self):
iters[iter_index] = iter(self._datasets[iter_index])
yield next(iters[iter_index])

def __len__(self) -> int:
    """Approximate total sample count across the interleaved datasets.

    Each dataset contributes ``weight * len(ds)`` samples; a dataset
    without a ``weight`` attribute defaults to weight 1. The product is
    truncated to an int per dataset.
    """
    # TODO: Implement the length method for different stop strategies
    # (LAST_EXHAUSTED vs FIRST_EXHAUSTED change the true yielded count).
    return sum(int(getattr(ds, "weight", 1) * len(ds)) for ds in self._datasets)
liPatrick marked this conversation as resolved.
Show resolved Hide resolved


class Dataproc(SizedIterableDataset):
Expand All @@ -1198,7 +1204,7 @@ def _process(self, sample: VoiceSample) -> Dict[str, Any]:
def __iter__(self):
    """Lazily apply ``_process`` to every sample of the wrapped dataset."""
    return map(self._process, self._dataset)

def __len__(self) -> int:
    # Per-sample processing is 1:1, so length delegates to the wrapped dataset.
    return len(self._dataset)


Expand Down Expand Up @@ -1234,7 +1240,7 @@ def __iter__(self):
break
yield sample

def __len__(self):
def __len__(self) -> int:
return (
self._num_samples
if self._num_samples is not None
Expand Down
2 changes: 1 addition & 1 deletion ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class FakeSizedIterableDataset(datasets.SizedIterableDataset):
"""Fake version of datasets.SizedIterableDataset"""

def __init__(self, n, start=0, weight=1, estimated_length=1):
    """Create a fake dataset of ``n`` integer samples in [start, start + n).

    ``weight`` and ``estimated_length`` mirror the attributes real
    SizedIterableDataset implementations expose, so tests can exercise
    weighting and length-estimation logic.
    """
    self.data = range(start, start + n)
    self._weight = weight
    self._estimated_length = estimated_length
Expand Down
2 changes: 1 addition & 1 deletion ultravox/training/configs/meta_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
text_model: "meta-llama/Meta-Llama-3-8B-Instruct"
audio_model: "facebook/wav2vec2-base-960h"

data_sets: ["gigaspeech"]
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
data_sets: []
val_sets: ["heysquad_human", "anyinstruct", "soda", "peoplespeech"]
stop_strategy: "LAST_EXHAUSTED"

Expand Down
9 changes: 7 additions & 2 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,20 @@ def prepare_dataset(
stop_strategy: datasets.StopStrategy,
num_samples: Optional[int] = None,
include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training)
is_val_set: bool = False,
liPatrick marked this conversation as resolved.
Show resolved Hide resolved
) -> datasets.SizedIterableDataset:
data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
# If we're using epochs to train, validate the dataset length is appropriate.
if train_args.max_steps == 0:
using_epochs = train_args.max_steps == 0
if using_epochs and not is_val_set:
for ds in data_sets:
assert (
len(ds) > 1
), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"

interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy)
interleave = datasets.InterleaveDataset(
data_sets, stop_strategy=stop_strategy, using_epochs=using_epochs
)
ds_with_proc = data_processing.UltravoxDataproc(
interleave,
processor=processor,
Expand Down Expand Up @@ -242,6 +246,7 @@ def train(args: config_base.TrainConfig):
num_samples=args.val_num_samples,
data_args=val_ds_args_text if k.startswith("text_") else val_ds_args,
include_alt_fields=model.loss_config.requires_alt_fields,
is_val_set=True,
)
for k in val_sets
}
Expand Down
Loading