Allow sample_rate parameter to audio decoder #551

Merged: 19 commits, Mar 20, 2025
16 changes: 9 additions & 7 deletions src/torchcodec/decoders/_audio_decoder.py
@@ -25,10 +25,13 @@ def __init__(
source: Union[str, Path, bytes, Tensor],
*,
stream_index: Optional[int] = None,
sample_rate: Optional[int] = None,
Member Author:

We could also consider exposing this as a desired_sample_rate parameter? I don't have a strong opinion; the docs would make it clear what this means in any case.

):
self._decoder = create_decoder(source=source, seek_mode="approximate")

core.add_audio_stream(self._decoder, stream_index=stream_index)
core.add_audio_stream(
self._decoder, stream_index=stream_index, sample_rate=sample_rate
)

(
self.metadata,
@@ -39,6 +42,9 @@ def __init__(
decoder=self._decoder, stream_index=stream_index, media_type="audio"
)
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
self._desired_sample_rate = (
sample_rate if sample_rate is not None else self.metadata.sample_rate
)

def get_samples_played_in_range(
self, start_seconds: float, stop_seconds: Optional[float] = None
@@ -75,11 +81,7 @@ def get_samples_played_in_range(
# So we do some basic math to figure out the position of the view that
# we'll return.

# TODO: sample_rate is either the original one from metadata, or the
# user-specified one (NIY)
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy
sample_rate = self.metadata.sample_rate

sample_rate = self._desired_sample_rate
# TODO: metadata's sample_rate should probably not be Optional
assert sample_rate is not None # mypy.

@@ -94,7 +96,7 @@
output_pts_seconds = first_pts

num_samples = frames.shape[1]
last_pts = first_pts + num_samples / self.metadata.sample_rate
last_pts = first_pts + num_samples / sample_rate
if stop_seconds is not None and stop_seconds < last_pts:
offset_end = num_samples - round((last_pts - stop_seconds) * sample_rate)
else:
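The view arithmetic in the Python diff above can be sketched on its own (a standalone re-derivation for illustration, not library code): the decoded frames form one contiguous block of samples starting at the first frame's pts, spaced 1/sample_rate apart, so clipping at stop_seconds reduces to a round of the overshoot.

```python
from typing import Optional

def tail_clip_offset(first_pts: float, num_samples: int, sample_rate: int,
                     stop_seconds: Optional[float]) -> int:
    """Number of samples to keep so the returned view ends at stop_seconds.

    Mirrors the logic in get_samples_played_in_range: last_pts is the end
    time of the decoded block, and any overshoot past stop_seconds is
    converted back into a sample count.
    """
    last_pts = first_pts + num_samples / sample_rate
    if stop_seconds is not None and stop_seconds < last_pts:
        return num_samples - round((last_pts - stop_seconds) * sample_rate)
    return num_samples
```

For example, decoding one second of 44100 Hz audio from pts 0.0 and asking to stop at 0.5s keeps exactly 22050 samples; with no stop_seconds the whole block is kept. Note that with a user-specified sample_rate, the same formula applies unchanged, which is why the diff swaps self.metadata.sample_rate for self._desired_sample_rate.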
13 changes: 7 additions & 6 deletions src/torchcodec/decoders/_core/FFMPEGCommon.cpp
@@ -86,20 +86,21 @@ void setChannelLayout(

SwrContext* allocateSwrContext(
UniqueAVCodecContext& avCodecContext,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat) {
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
SwrContext* swrContext = nullptr;
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
AVChannelLayout layout = avCodecContext->ch_layout;
auto status = swr_alloc_set_opts2(
&swrContext,
&layout,
desiredSampleFormat,
sampleRate,
desiredSampleRate,
&layout,
sourceSampleFormat,
sampleRate,
sourceSampleRate,
0,
nullptr);

@@ -113,10 +114,10 @@
nullptr,
layout,
desiredSampleFormat,
sampleRate,
desiredSampleRate,
layout,
sourceSampleFormat,
sampleRate,
sourceSampleRate,
0,
nullptr);
#endif
5 changes: 3 additions & 2 deletions src/torchcodec/decoders/_core/FFMPEGCommon.h
@@ -149,9 +149,10 @@ void setChannelLayout(
const UniqueAVFrame& srcAVFrame);
SwrContext* allocateSwrContext(
UniqueAVCodecContext& avCodecContext,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat);
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// Returns true if sws_scale can handle unaligned data.
bool canSwsScaleHandleUnalignedData();
143 changes: 107 additions & 36 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -580,14 +580,18 @@ void VideoDecoder::addVideoStream(
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
}

void VideoDecoder::addAudioStream(int streamIndex) {
void VideoDecoder::addAudioStream(
int streamIndex,
const AudioStreamOptions& audioStreamOptions) {
TORCH_CHECK(
seekMode_ == SeekMode::approximate,
"seek_mode must be 'approximate' for audio streams.");

addStream(streamIndex, AVMEDIA_TYPE_AUDIO);

auto& streamInfo = streamInfos_[activeStreamIndex_];
streamInfo.audioStreamOptions = audioStreamOptions;

auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
streamMetadata.sampleRate =
@@ -947,6 +951,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
(stopPts <= lastDecodedAvFrameEnd);
}

auto lastSamples = maybeFlushSwrBuffers();
if (lastSamples.has_value()) {
frames.push_back(*lastSamples);
}
Member Author:

Not particularly fond of the above. Maybe we could let maybeFlushSwrBuffers return a tensor of shape (numChannels, 0), which could probably be push_back()'d unconditionally. Not sure that's better; there are probably nicer patterns I'm not seeing?

Contributor:

Everything I can think of that does things unconditionally in this function is way too cute (returning a vector of tensors; using std::copy()). It may be clearer about intent if maybeFlushSwrBuffers() returns an optional, so that we don't need an empty tensor to indicate there is nothing to do.

Member Author:

Using optional sounds better, thanks.


return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds};
}
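The optional-return pattern the reviewers settle on can be sketched in Python (a toy model with hypothetical names; the real code returns std::optional<torch::Tensor> and the caller push_back()s conditionally): the flush helper returns None when the resampler has nothing buffered, so the caller never has to concatenate an empty chunk.

```python
from typing import List, Optional

class ResamplerSketch:
    """Toy stand-in for swresample's internal buffering (hypothetical)."""

    def __init__(self) -> None:
        self._buffered: List[float] = []

    def convert(self, samples: List[float]) -> List[float]:
        # Pretend the last 2 samples always need future context, so the
        # resampler holds them back -- like swr_convert with rate conversion.
        self._buffered.extend(samples[-2:])
        return samples[:-2]

    def maybe_flush(self) -> Optional[List[float]]:
        # Mirrors maybeFlushSwrBuffers: None signals "nothing left to emit".
        if not self._buffered:
            return None
        out, self._buffered = self._buffered, []
        return out

frames: List[List[float]] = []
resampler = ResamplerSketch()
frames.append(resampler.convert([0.1, 0.2, 0.3, 0.4]))

last = resampler.maybe_flush()
if last is not None:  # conditional append, as in the PR
    frames.append(last)
```

A second maybe_flush() call returns None, which is exactly the case the empty-tensor alternative would have papered over.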

@@ -1200,8 +1209,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
getDuration(avFrame),
formatContext_->streams[activeStreamIndex_]->time_base);
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
convertAudioAVFrameToFrameOutputOnCPU(
avFrame, frameOutput, preAllocatedOutputTensor);
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput);
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
convertAVFrameToFrameOutputOnCPU(
avFrame, frameOutput, preAllocatedOutputTensor);
@@ -1379,24 +1387,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(

void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
UniqueAVFrame& srcAVFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
TORCH_CHECK(
!preAllocatedOutputTensor.has_value(),
"pre-allocated audio tensor not supported yet.");

FrameOutput& frameOutput) {
AVSampleFormat sourceSampleFormat =
static_cast<AVSampleFormat>(srcAVFrame->format);
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;

int sourceSampleRate = srcAVFrame->sample_rate;
int desiredSampleRate =
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
sourceSampleRate);

bool mustConvert =
(sourceSampleFormat != desiredSampleFormat ||
sourceSampleRate != desiredSampleRate);

UniqueAVFrame convertedAVFrame;
if (sourceSampleFormat != desiredSampleFormat) {
convertedAVFrame = convertAudioAVFrameSampleFormat(
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
if (mustConvert) {
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
srcAVFrame,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);
}
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
? convertedAVFrame
: srcAVFrame;
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;

AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
TORCH_CHECK(
@@ -1419,55 +1433,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
memcpy(
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel);
}

frameOutput.data = outputData;
}

UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat(
const UniqueAVFrame& avFrame,
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueAVFrame& srcAVFrame,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat

) {
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
auto& streamInfo = streamInfos_[activeStreamIndex_];
const auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value());

if (!streamInfo.swrContext) {
createSwrContext(
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat);
streamInfo,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);
}

UniqueAVFrame convertedAVFrame(av_frame_alloc());
TORCH_CHECK(
convertedAVFrame,
"Could not allocate frame for sample format conversion.");

setChannelLayout(convertedAVFrame, avFrame);
setChannelLayout(convertedAVFrame, srcAVFrame);
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
convertedAVFrame->sample_rate = avFrame->sample_rate;
convertedAVFrame->nb_samples = avFrame->nb_samples;
convertedAVFrame->sample_rate = desiredSampleRate;
if (sourceSampleRate != desiredSampleRate) {
// Note that this is an upper bound on the number of output samples.
// `swr_convert()` will likely not fill convertedAVFrame with that many
// samples if sample rate conversion is needed. It will buffer the last few
// ones because those require future samples. That's also why we reset
// nb_samples after the call to `swr_convert()`.
// We could also use `swr_get_out_samples()` to determine the number of
// output samples, but empirically `av_rescale_rnd()` seems to provide a
// tighter bound.
convertedAVFrame->nb_samples = av_rescale_rnd(
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) +
srcAVFrame->nb_samples,
desiredSampleRate,
sourceSampleRate,
AV_ROUND_UP);
} else {
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
}

auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
TORCH_CHECK(
status == AVSUCCESS,
"Could not allocate frame buffers for sample format conversion: ",
getFFMPEGErrorStringFromErrorCode(status));

auto numSampleConverted = swr_convert(
auto numConvertedSamples = swr_convert(
streamInfo.swrContext.get(),
convertedAVFrame->data,
convertedAVFrame->nb_samples,
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)),
avFrame->nb_samples);
static_cast<const uint8_t**>(
const_cast<const uint8_t**>(srcAVFrame->data)),
srcAVFrame->nb_samples);
TORCH_CHECK(
numSampleConverted > 0,
numConvertedSamples > 0,
"Error in swr_convert: ",
getFFMPEGErrorStringFromErrorCode(numSampleConverted));
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));

// See comment above about nb_samples
convertedAVFrame->nb_samples = numConvertedSamples;

return convertedAVFrame;
}
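The upper bound on nb_samples computed above can be reproduced with exact integer math (a re-derivation for illustration: av_rescale_rnd(a, b, c, AV_ROUND_UP) is ceil(a*b/c) when it does not overflow).

```python
def max_output_samples(delay: int, nb_samples: int,
                       desired_rate: int, source_rate: int) -> int:
    """Upper bound on resampler output samples for one input frame.

    Equivalent to av_rescale_rnd(delay + nb_samples, desired_rate,
    source_rate, AV_ROUND_UP): the buffered delay plus the new input
    samples, rescaled to the output rate, rounded up.
    """
    num = (delay + nb_samples) * desired_rate
    return -(-num // source_rate)  # ceiling division on ints
```

For instance, downsampling a 1024-sample frame from 44100 Hz to 16000 Hz with no buffered delay bounds the output at 372 samples; with equal rates the bound degenerates to nb_samples, matching the else branch in the diff. swr_convert() then reports how many samples it actually wrote, which is why the code resets nb_samples afterwards.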

std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() {
// When sample rate conversion is involved, swresample buffers some of the
// samples in-between calls to swr_convert (see the libswresample docs).
// That's because the last few samples in a given frame require future samples
// from the next frame to be properly converted. This function flushes out the
// samples that are stored in swresample's buffers.
auto& streamInfo = streamInfos_[activeStreamIndex_];
if (!streamInfo.swrContext) {
return std::nullopt;
}
auto numRemainingSamples = // this is an upper bound
swr_get_out_samples(streamInfo.swrContext.get(), 0);

if (numRemainingSamples == 0) {
return std::nullopt;
}

torch::Tensor lastSamples = torch::empty(
{getNumChannels(streamInfo.codecContext), numRemainingSamples},
torch::kFloat32);
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr());

auto actualNumRemainingSamples = swr_convert(
streamInfo.swrContext.get(),
&lastSamplesData,
numRemainingSamples,
nullptr,
0);
return lastSamples.narrow(
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
}

// --------------------------------------------------------------------------
// OUTPUT ALLOCATION AND SHAPE CONVERSION
// --------------------------------------------------------------------------
@@ -1703,14 +1772,16 @@ void VideoDecoder::createSwsContext(

void VideoDecoder::createSwrContext(
StreamInfo& streamInfo,
int sampleRate,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat) {
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
auto swrContext = allocateSwrContext(
streamInfo.codecContext,
sampleRate,
sourceSampleFormat,
desiredSampleFormat);
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);

auto status = swr_init(swrContext);
TORCH_CHECK(