
#4 Implement AudioToTextPipeline #34

Merged: 41 commits from feature/speech-pipeline into facebookresearch:main, Sep 6, 2024

Conversation

botirk38 (Collaborator):

Why?

This PR implements an AudioToTextHFPipeline for transcribing audio datasets from HuggingFace into text using SONAR. The key reasons for this implementation are:

  1. To provide a standardized pipeline for audio transcription tasks within our project.
  2. To leverage SONAR's speech-to-text capabilities for processing audio datasets from HuggingFace.
  3. To enable easy configuration and customization of audio transcription parameters.
  4. To integrate seamlessly with our existing Pipeline framework and HuggingFace datasets.
  5. To support batch processing for efficient handling of large audio datasets.

This implementation will improve our ability to process and transcribe audio datasets consistently, making it easier to prepare text data for further analysis or model training.

How?

Key technical decisions and implementations:

  1. Extended the existing Pipeline and PipelineConfig classes to create AudioToTextHFPipeline and AudioPipelineConfig (see the sketch after this list).
  2. Integrated SONAR's SpeechToTextPipeline for audio transcription.
  3. Implemented batch processing of audio data with configurable batch sizes.
  4. Used torch.inference_mode() for efficient inference.
  5. Included detailed logging for better monitoring and debugging.
  6. Implemented error handling throughout the pipeline.
  7. Allowed for customization of encoder and decoder models, target language, and other transcription parameters.
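
A minimal sketch of how these pieces fit together. SpeechToTextModelPipeline and its predict() arguments follow SONAR's published API, while the config fields and class shape here are illustrative assumptions, not the merged implementation:

    from dataclasses import dataclass

    import torch
    from sonar.inference_pipelines.speech import SpeechToTextModelPipeline


    @dataclass
    class AudioPipelineConfig:
        # Illustrative fields; the real config extends PipelineConfig.
        encoder_model: str = "sonar_speech_encoder_eng"
        decoder_model: str = "text_sonar_basic_decoder"
        target_lang: str = "eng_Latn"
        batch_size: int = 16


    class AudioToTextHFPipeline:
        """Transcribes batches of audio files to text with SONAR."""

        def __init__(self, config: AudioPipelineConfig):
            self.config = config
            self.model = SpeechToTextModelPipeline(
                encoder=config.encoder_model,
                decoder=config.decoder_model,
                tokenizer=config.decoder_model,
            )

        def transcribe_audio(self, audio_paths: list) -> list:
            # inference_mode() skips autograd bookkeeping during prediction.
            with torch.inference_mode():
                return self.model.predict(
                    audio_paths,
                    target_lang=self.config.target_lang,
                    batch_size=self.config.batch_size,
                )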

Work in Progress (WIP) parts:

  • The error handling in the transcribe_audio method could be more specific to different types of errors that might occur during transcription.

Test plan

To test these changes, we should:

  1. Create unit tests for the AudioToTextHFPipeline and AudioPipelineConfig classes (a minimal example follows this list).
  2. Implement integration tests with sample audio datasets from HuggingFace.
  3. Test with different audio formats and languages to ensure robustness.
  4. Verify that batch processing works correctly for various batch sizes.
  5. Test the pipeline with different SONAR encoder and decoder models.
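
For instance, a unit test for the embedding-shape normalization discussed in the review below might look like this (the ensure_2d helper is a stand-in for the pipeline's internal logic, not its actual name):

    import torch


    def ensure_2d(embeddings):
        # Mirrors the pipeline's normalization: promote 1-D embeddings to 2-D.
        return [e.unsqueeze(0) if e.dim() == 1 else e for e in embeddings]


    def test_ensure_2d_handles_mixed_ranks():
        embeddings = [torch.rand(1024), torch.rand(3, 1024)]
        out = ensure_2d(embeddings)
        assert all(e.dim() == 2 for e in out)
        assert out[0].shape == (1, 1024)
        assert out[1].shape == (3, 1024)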

@facebook-github-bot added the CLA Signed label on Jul 19, 2024
@botirk38 changed the title from "Implement AudioToTextPipeline" to "#6 Implement AudioToTextPipeline" on Aug 2, 2024
@botirk38 changed the title from "#6 Implement AudioToTextPipeline" to "#4 Implement AudioToTextPipeline" on Aug 12, 2024


@dataclass
class AudioDatasetConfig(DatasetConfig):

Contributor: Where is this class used?

Collaborator Author: When we initialize audio HuggingFace datasets.

)

# Ensure all embeddings are 2D
all_embeddings = [emb.unsqueeze(0) if emb.dim() == 1 else emb for emb in all_embeddings]

Contributor: What happens if the audio inputs have multiple channels?

Collaborator Author: We can convert multi-channel audio to mono by taking the mean across channels; raising an error instead would make the pipeline less robust.
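
A sketch of that conversion, assuming the HF Audio feature's decoded dict with "array" and "sampling_rate" keys and a channel-first layout (flip the axis if your loader yields channel-last):

    import numpy as np


    def to_mono(audio_data: dict) -> np.ndarray:
        # Average across the channel axis; assumed shape (channels, samples).
        array = np.asarray(audio_data["array"])
        if array.ndim == 2:
            array = array.mean(axis=0)
        return array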

audio_data['array'], audio_data['sampling_rate'])
audio_inputs.append(temp_file.name)
else:
logger.warning(f"Invalid audio data format: {audio_data}")

Contributor: Printing all of audio_data might be overwhelming, especially in the terminal.

Collaborator Author: I added this logging in trace mode; it may be useful to check that the wav dim and shape are correct.
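
One way to keep that trace output compact is to log only summary fields rather than the raw samples (a sketch; logger and audio_data are the names from the surrounding code, np is numpy):

    if isinstance(audio_data, dict) and "array" in audio_data:
        arr = np.asarray(audio_data["array"])
        logger.warning(
            "Invalid audio data: shape=%s dtype=%s sampling_rate=%s",
            arr.shape, arr.dtype, audio_data.get("sampling_rate"),
        )
    else:
        logger.warning("Invalid audio data of type %s", type(audio_data).__name__)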

)

"""

Contributor: An extra dtype param is needed here.

Collaborator Author: Done.

Comment on lines +136 to +138
if column not in batch:
logger.warning(f"Column {column} not found in batch. Skipping.")
continue

Contributor: I would raise in this case instead of skipping.

Collaborator Author: Done.

and "sampling_rate" in audio_data
):
# Handle multi-channel audio by taking the mean across channels
audio_array = audio_data["array"]

Contributor: Is this always the convention used by HF? If not, we need to add it as a param.

audio-specific attributes and processing.

Attributes:
sampling_rate (int): The target sampling rate for audio data.

Contributor: Leave a comment about the HF integration.

Collaborator Author: Done.


# Ensure all embeddings are 2D
processed_embeddings: List[torch.Tensor] = [
emb.unsqueeze(0) if emb.dim() == 1 else emb

Contributor: Do you really need to do this?

for emb in all_embeddings
]

# Get the maximum sequence length

Contributor: Maybe we don't need this padding at all, since it's already padded?
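
For reference, the manual max-length padding under discussion is equivalent to torch.nn.utils.rnn.pad_sequence, assuming each embedding is (seq_len, dim) after the 2-D normalization above:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    embeddings = [torch.rand(3, 1024), torch.rand(5, 1024)]  # ragged lengths
    padded = pad_sequence(embeddings, batch_first=True)  # (2, 5, 1024), zero-padded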

Comment on lines 217 to 221
logger.error(
f"Error in model.predict for column {column}: {str(e)}"
)
# Instead of raising, we'll set the output to None and continue processing
batch[f"{column}_{self.config.output_column_suffix}"] = None

Contributor: I would just raise, to keep things simple.

Args:
dataset (datasets.Dataset): The loaded dataset.

Returns:

Collaborator Author: Add a link to the HF cast_column API here.

datasets.Dataset: The dataset with processed audio column.
"""
if self.audio_column in dataset.column_names:
dataset = dataset.cast_column(

Collaborator Author: Add to the docs that casting the column will modify the original column.
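
The relevant datasets API, for the docs (standard usage; the dataset name here is just a common public example):

    from datasets import Audio, load_dataset

    ds = load_dataset("PolyAI/minds14", name="en-US", split="train")
    # Re-decodes the audio column at the target sampling rate; note that the
    # original column is replaced on the returned dataset.
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))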

fbank_dtype: torch.dtype = torch.float32
n_parallel: int = 4
pad_idx: int = 0
dtype: np.dtype = np.dtype(np.float32)

Collaborator Author: No need for the np.dtype constructor here.


try:
# Move tensors to the specified device
audio_inputs = [

Collaborator Author: Move tensors in mini-batches to avoid out-of-memory errors.
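
A sketch of the suggested chunking (the helper, the mini_batch_size knob, and a model callable that accepts a list of tensors are all illustrative, not part of the final code):

    import torch


    def predict_in_mini_batches(model, audio_tensors, device, mini_batch_size=8):
        # Only one mini-batch of tensors lives on the device at a time,
        # which bounds peak memory during inference.
        outputs = []
        with torch.inference_mode():
            for i in range(0, len(audio_tensors), mini_batch_size):
                chunk = [t.to(device) for t in audio_tensors[i:i + mini_batch_size]]
                outputs.extend(o.cpu() for o in model(chunk))
        return outputs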

@pytest.fixture
def complex_audio_data() -> Dict[str, Dict[str, Any]]:
return {
"short_audio": {"array": np.random.rand(8000), "sampling_rate": 16000},

Collaborator Author: Verify correctness by decoding the output tensors.

@botirk38 merged commit 8c5bf17 into facebookresearch:main on Sep 6, 2024
5 checks passed
@botirk38 deleted the feature/speech-pipeline branch on Sep 6, 2024