Allow sample_rate parameter to audio decoder #551

Merged
merged 19 commits on Mar 20, 2025

Changes from 15 commits
16 changes: 9 additions & 7 deletions src/torchcodec/decoders/_audio_decoder.py
@@ -25,10 +25,13 @@ def __init__(
source: Union[str, Path, bytes, Tensor],
*,
stream_index: Optional[int] = None,
sample_rate: Optional[int] = None,
Member Author:

We could also consider exposing this as a desired_sample_rate parameter? I don't have a strong opinion; the docs would make the meaning clear either way.

):
self._decoder = create_decoder(source=source, seek_mode="approximate")

core.add_audio_stream(self._decoder, stream_index=stream_index)
core.add_audio_stream(
self._decoder, stream_index=stream_index, sample_rate=sample_rate
)

(
self.metadata,
@@ -39,6 +42,9 @@ def __init__(
decoder=self._decoder, stream_index=stream_index, media_type="audio"
)
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
self._desired_sample_rate = (
sample_rate if sample_rate is not None else self.metadata.sample_rate
)

def get_samples_played_in_range(
self, start_seconds: float, stop_seconds: Optional[float] = None
@@ -75,11 +81,7 @@ def get_samples_played_in_range(
# So we do some basic math to figure out the position of the view that
# we'll return.

# TODO: sample_rate is either the original one from metadata, or the
# user-specified one (NIY)
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
sample_rate = self.metadata.sample_rate

sample_rate = self._desired_sample_rate
# TODO: metadata's sample_rate should probably not be Optional
assert sample_rate is not None # mypy.

@@ -94,7 +96,7 @@
output_pts_seconds = first_pts

num_samples = frames.shape[1]
last_pts = first_pts + num_samples / self.metadata.sample_rate
last_pts = first_pts + num_samples / sample_rate
if stop_seconds is not None and stop_seconds < last_pts:
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
else:
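To make the trimming math above concrete: a small self-contained C++ sketch with made-up numbers, mirroring how last_pts and the end offset are derived from first_pts, the desired sample rate, and stop_seconds. Names follow the Python above; this is an illustration, not decoder code.

#include <cmath>
#include <cstdio>

int main() {
  // Assumed example values, not taken from the PR:
  double firstPts = 10.0;     // pts of the first returned sample, in seconds
  int numSamples = 16000;     // decoded samples in the batch
  int sampleRate = 16000;     // desired (possibly resampled) rate, not metadata's
  double stopSeconds = 10.25;

  // last_pts = first_pts + num_samples / sample_rate
  double lastPts = firstPts + static_cast<double>(numSamples) / sampleRate; // 11.0

  // offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
  int offsetEnd = numSamples -
      static_cast<int>(std::lround((lastPts - stopSeconds) * sampleRate));

  std::printf("keep samples [0, %d)\n", offsetEnd); // keep samples [0, 4000)
  return 0;
}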
13 changes: 7 additions & 6 deletions src/torchcodec/decoders/_core/FFMPEGCommon.cpp
@@ -86,20 +86,21 @@ void setChannelLayout(

SwrContext* allocateSwrContext(
UniqueAVCodecContext& avCodecContext,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat) {
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
SwrContext* swrContext = nullptr;
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
AVChannelLayout layout = avCodecContext->ch_layout;
auto status = swr_alloc_set_opts2(
&swrContext,
&layout,
desiredSampleFormat,
sampleRate,
desiredSampleRate,
&layout,
sourceSampleFormat,
sampleRate,
sourceSampleRate,
0,
nullptr);

@@ -113,10 +114,10 @@
nullptr,
layout,
desiredSampleFormat,
sampleRate,
desiredSampleRate,
layout,
sourceSampleFormat,
sampleRate,
sourceSampleRate,
0,
nullptr);
#endif
5 changes: 3 additions & 2 deletions src/torchcodec/decoders/_core/FFMPEGCommon.h
@@ -149,9 +149,10 @@ void setChannelLayout(
const UniqueAVFrame& srcAVFrame);
SwrContext* allocateSwrContext(
UniqueAVCodecContext& avCodecContext,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat);
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// Returns true if sws_scale can handle unaligned data.
bool canSwsScaleHandleUnalignedData();
140 changes: 104 additions & 36 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -546,14 +546,18 @@ void VideoDecoder::addVideoStream(
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
}

void VideoDecoder::addAudioStream(int streamIndex) {
void VideoDecoder::addAudioStream(
int streamIndex,
const AudioStreamOptions& audioStreamOptions) {
TORCH_CHECK(
seekMode_ == SeekMode::approximate,
"seek_mode must be 'approximate' for audio streams.");

addStream(streamIndex, AVMEDIA_TYPE_AUDIO);

auto& streamInfo = streamInfos_[activeStreamIndex_];
streamInfo.audioStreamOptions = audioStreamOptions;

auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
streamMetadata.sampleRate =
@@ -913,6 +917,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
(stopPts <= lastDecodedAvFrameEnd);
}

torch::Tensor lastSamples = maybeFlushSwrBuffers();
if (lastSamples.numel() > 0) {
frames.push_back(lastSamples);
}
Member Author:

Not particularly fond of the above. Maybe we could let maybeFlushSwrBuffers return a tensor of shape (numChannels, 0), which could then be push_back()'d unconditionally. Not sure that's better; there are probably nicer patterns I'm not seeing?

Contributor:

Everything I can think of that does things unconditionally in this function is way too cute (returning a vector of tensors; using std::copy()). It may be clearer about intent if maybeFlushSwrBuffers() returned an optional, so that we don't need an empty tensor to indicate there is nothing to do.

Member Author:

Using an optional sounds better, thanks.

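A minimal, self-contained sketch of the optional-based shape discussed above. Chunk stands in for torch::Tensor and maybeFlushBuffers() for maybeFlushSwrBuffers(), so this illustrates the call-site pattern, not the merged code.

#include <optional>
#include <utility>
#include <vector>

struct Chunk {
  std::vector<float> samples; // stand-in for torch::Tensor
};

// Returning std::nullopt makes "nothing left to flush" explicit instead of
// signaling it with an empty tensor.
std::optional<Chunk> maybeFlushBuffers(bool hasBufferedSamples) {
  if (!hasBufferedSamples) {
    return std::nullopt;
  }
  return Chunk{{0.1f, 0.2f}}; // placeholder for the flushed samples
}

int main() {
  std::vector<Chunk> frames;
  if (auto lastSamples = maybeFlushBuffers(/*hasBufferedSamples=*/true)) {
    frames.push_back(std::move(*lastSamples)); // push only when there is data
  }
  return 0;
}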

return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
}

@@ -1166,8 +1175,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
getDuration(avFrame),
formatContext_->streams[activeStreamIndex_]->time_base);
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
convertAudioAVFrameToFrameOutputOnCPU(
avFrame, frameOutput, preAllocatedOutputTensor);
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
convertAVFrameToFrameOutputOnCPU(
avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1345,24 +1353,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(

void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
UniqueAVFrame& srcAVFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
TORCH_CHECK(
!preAllocatedOutputTensor.has_value(),
"pre-allocated audio tensor not supported yet.");

FrameOutput& frameOutput) {
AVSampleFormat sourceSampleFormat =
static_cast<AVSampleFormat>(srcAVFrame->format);
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;

int sourceSampleRate = srcAVFrame->sample_rate;
int desiredSampleRate =
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
sourceSampleRate);

bool mustConvert =
(sourceSampleFormat != desiredSampleFormat ||
sourceSampleRate != desiredSampleRate);

UniqueAVFrame convertedAVFrame;
if (sourceSampleFormat != desiredSampleFormat) {
convertedAVFrame = convertAudioAVFrameSampleFormat(
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
if (mustConvert) {
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
srcAVFrame,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);
}
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
? convertedAVFrame
: srcAVFrame;
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;

AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
TORCH_CHECK(
@@ -1385,55 +1399,107 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
memcpy(
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
}

frameOutput.data = outputData;
}

UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
const UniqueAVFrame& avFrame,
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueAVFrame& srcAVFrame,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat

) {
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
auto& streamInfo = streamInfos_[activeStreamIndex_];
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());

if (!streamInfo.swrContext) {
createSwrContext(
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
streamInfo,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);
}

UniqueAVFrame convertedAVFrame(av_frame_alloc());
TORCH_CHECK(
convertedAVFrame,
"Could not allocate frame for sample format conversion.");

setChannelLayout(convertedAVFrame, avFrame);
setChannelLayout(convertedAVFrame, srcAVFrame);
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
convertedAVFrame->sample_rate = avFrame->sample_rate;
convertedAVFrame->nb_samples = avFrame->nb_samples;
convertedAVFrame->sample_rate = desiredSampleRate;
if (sourceSampleRate != desiredSampleRate) {
// Note that this is an upper bound on the number of output samples.
// `swr_convert()` will likely not fill convertedAVFrame with that many
// samples if sample rate conversion is needed. It will buffer the last few
// ones because those require future samples. That's also why we reset
// nb_samples after the call to `swr_convert()`.
convertedAVFrame->nb_samples = av_rescale_rnd(
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) +
srcAVFrame->nb_samples,
desiredSampleRate,
sourceSampleRate,
AV_ROUND_UP);
} else {
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
}

auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
TORCH_CHECK(
status == AVSUCCESS,
"Could not allocate frame buffers for sample format conversion: ",
getFFMPEGErrorStringFromErrorCode(status));

auto numSampleConverted = swr_convert(
auto numConvertedSamples = swr_convert(
streamInfo.swrContext.get(),
convertedAVFrame->data,
convertedAVFrame->nb_samples,
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)),
avFrame->nb_samples);
static_cast<const uint8_t**>(
const_cast<const uint8_t**>(srcAVFrame->data)),
srcAVFrame->nb_samples);
TORCH_CHECK(
numSampleConverted > 0,
numConvertedSamples > 0,
"Error in swr_convert: ",
getFFMPEGErrorStringFromErrorCode(numSampleConverted));
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));

// See comment above about nb_samples
convertedAVFrame->nb_samples = numConvertedSamples;

return convertedAVFrame;
}

torch::Tensor VideoDecoder::maybeFlushSwrBuffers() {
// When sample rate conversion is involved, swresample buffers some of the
// samples in-between calls to swr_convert (see the libswresample docs).
// That's because the last few samples in a given frame require future samples
// from the next frame to be properly converted. This function flushes out the
// samples that are stored in swresample's buffers.
auto& streamInfo = streamInfos_[activeStreamIndex_];
if (!streamInfo.swrContext) {
return torch::empty({0, 0});
}
auto numRemainingSamples = // this is an upper bound
swr_get_out_samples(streamInfo.swrContext.get(), 0);

if (numRemainingSamples == 0) {
return torch::empty({0, 0});
}

torch::Tensor lastSamples = torch::empty(
{getNumChannels(streamInfo.codecContext), numRemainingSamples},
torch::kFloat32);
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr());

auto actualNumRemainingSamples = swr_convert(
streamInfo.swrContext.get(),
&lastSamplesData,
numRemainingSamples,
NULL,
Contributor:

s/NULL/nullptr/g

0);
return lastSamples.narrow(
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
}

// --------------------------------------------------------------------------
// OUTPUT ALLOCATION AND SHAPE CONVERSION
// --------------------------------------------------------------------------
@@ -1669,14 +1735,16 @@ void VideoDecoder::createSwsContext(

void VideoDecoder::createSwrContext(
StreamInfo& streamInfo,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat) {
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
auto swrContext = allocateSwrContext(
streamInfo.codecContext,
sampleRate,
sourceSampleFormat,
desiredSampleFormat);
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);

auto status = swr_init(swrContext);
TORCH_CHECK(
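As a concrete check on the upper-bound comment in convertAudioAVFrameSampleFormatAndSampleRate(): a self-contained sketch mirroring the arithmetic of av_rescale_rnd(delay + nb_samples, outRate, inRate, AV_ROUND_UP), with made-up numbers. The real av_rescale_rnd also guards against overflow, which this toy version does not.

#include <cstdint>
#include <cstdio>

// Toy equivalent of av_rescale_rnd(a, b, c, AV_ROUND_UP), overflow ignored.
int64_t rescaleRoundUp(int64_t a, int64_t b, int64_t c) {
  return (a * b + c - 1) / c;
}

int main() {
  // Assumed example values, not taken from the PR:
  int64_t delay = 12;       // samples still buffered inside swresample
  int64_t nbSamples = 1024; // samples in the incoming frame
  int64_t sourceSampleRate = 44100;
  int64_t desiredSampleRate = 16000;

  // Upper bound on output samples; swr_convert() may produce fewer, which is
  // why the code above resets nb_samples to swr_convert()'s return value.
  int64_t bound =
      rescaleRoundUp(delay + nbSamples, desiredSampleRate, sourceSampleRate);
  std::printf("allocate for %lld samples\n", static_cast<long long>(bound)); // 376
  return 0;
}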
29 changes: 21 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -139,10 +139,18 @@ class VideoDecoder {
torch::Device device = torch::kCPU;
};

struct AudioStreamOptions {
AudioStreamOptions() {}

std::optional<int> sampleRate;
};

Contributor:

I don't think we need the indirection of an options struct. I know it mirrors the pattern established on the video side, but I don't think it's good practice there, either. It's harder to get rid of the video options because we accept a string and then do a bunch of work parsing it. Getting rid of VideoStreamOptions will mean updating a bunch of callers to pass real arguments instead of a string.

Member Author (@NicolasHug, Mar 20, 2025):

I'm not a fan of the stringy video options either; that's why I didn't implement a string constructor for the audio options. I also opened #577 to entirely remove the string video options (it's not as hard as we thought).

I don't feel very strongly about this, but if we were to collapse both the video and audio options into the video decoder (or the StreamInfo), we would end up with a lot of video-only and audio-only fields in the same struct. I personally find it cleaner to keep those in separate structs. Additionally, using option structs makes it very clear which fields/values come from user-specified parameters, in contrast to e.g. metadata or video properties, which often helps to immediately understand the source of the values, as e.g. here:

int sourceSampleRate = srcAVFrame->sample_rate;
int desiredSampleRate =
    streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
        sourceSampleRate);

LMK your thoughts. I'm fine with collapsing sampleRate into a StreamInfo field if you prefer, but we'd potentially lose the two benefits mentioned above.

Contributor:

Yeah, that's a good point: we actually store the options. I was thinking just from the addAudioStream() API perspective. Let's keep this, then, and eventually get rid of the stringy options for video.
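For illustration, a self-contained sketch of the options-struct pattern being kept here. The real struct and signature are in VideoDecoder.h above; the default rate below is an assumed stand-in for the source's own rate.

#include <cstdio>
#include <optional>

// An optional field records only what the user explicitly asked for.
struct AudioStreamOptions {
  std::optional<int> sampleRate;
};

void addAudioStream(
    int streamIndex,
    const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()) {
  // value_or() makes the fallback explicit, mirroring how the decoder falls
  // back to srcAVFrame->sample_rate when no rate was requested.
  int desiredSampleRate = audioStreamOptions.sampleRate.value_or(44100);
  std::printf("stream %d -> %d Hz\n", streamIndex, desiredSampleRate);
}

int main() {
  AudioStreamOptions audioStreamOptions;
  audioStreamOptions.sampleRate = 16000; // user-requested rate (example value)
  addAudioStream(0, audioStreamOptions); // resample to 16 kHz
  addAudioStream(0);                     // default: keep the source rate
  return 0;
}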

void addVideoStream(
int streamIndex,
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
void addAudioStream(int streamIndex);
void addAudioStream(
int streamIndex,
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());

// --------------------------------------------------------------------------
// DECODING AND SEEKING APIs
@@ -336,6 +344,7 @@ class VideoDecoder {
int64_t lastDecodedAvFramePts = 0;
int64_t lastDecodedAvFrameDuration = 0;
VideoStreamOptions videoStreamOptions;
AudioStreamOptions audioStreamOptions;

// color-conversion fields. Only one of FilterGraphContext and
// UniqueSwsContext should be non-null.
@@ -382,8 +391,7 @@

void convertAudioAVFrameToFrameOutputOnCPU(
UniqueAVFrame& srcAVFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
FrameOutput& frameOutput);

torch::Tensor convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame);
@@ -392,10 +400,14 @@
const UniqueAVFrame& avFrame,
torch::Tensor& outputTensor);

UniqueAVFrame convertAudioAVFrameSampleFormat(
const UniqueAVFrame& avFrame,
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueAVFrame& srcAVFrame,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat);
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

torch::Tensor maybeFlushSwrBuffers();

// --------------------------------------------------------------------------
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
Expand All @@ -413,9 +425,10 @@ class VideoDecoder {

void createSwrContext(
StreamInfo& streamInfo,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat);
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// --------------------------------------------------------------------------
// PTS <-> INDEX CONVERSIONS