-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow sample_rate
parameter to audio decoder
#551
Changes from 15 commits
f6a7f4e
179a01c
2d97555
9af4bc8
2adf496
ef93be4
db740a6
ca15232
6aa7b09
f858d0c
70ac31e
8deb079
7b09315
af4e88a
975b0fb
7cb2271
ee1c7b7
bf9aed2
f0e2cdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -546,14 +546,18 @@ void VideoDecoder::addVideoStream( | |
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); | ||
} | ||
|
||
void VideoDecoder::addAudioStream(int streamIndex) { | ||
void VideoDecoder::addAudioStream( | ||
int streamIndex, | ||
const AudioStreamOptions& audioStreamOptions) { | ||
TORCH_CHECK( | ||
seekMode_ == SeekMode::approximate, | ||
"seek_mode must be 'approximate' for audio streams."); | ||
|
||
addStream(streamIndex, AVMEDIA_TYPE_AUDIO); | ||
|
||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
streamInfo.audioStreamOptions = audioStreamOptions; | ||
|
||
auto& streamMetadata = | ||
containerMetadata_.allStreamMetadata[activeStreamIndex_]; | ||
streamMetadata.sampleRate = | ||
|
@@ -913,6 +917,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( | |
(stopPts <= lastDecodedAvFrameEnd); | ||
} | ||
|
||
torch::Tensor lastSamples = maybeFlushSwrBuffers(); | ||
if (lastSamples.numel() > 0) { | ||
frames.push_back(lastSamples); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not particularly fond of the above. Maybe we could let There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Everything I can think of that does things unconditionally in this function is way too cute (returning a vector of tensors; using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using optional sounds better, thanks |
||
|
||
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds}; | ||
} | ||
|
||
|
@@ -1166,8 +1175,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( | |
getDuration(avFrame), | ||
formatContext_->streams[activeStreamIndex_]->time_base); | ||
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { | ||
convertAudioAVFrameToFrameOutputOnCPU( | ||
avFrame, frameOutput, preAllocatedOutputTensor); | ||
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); | ||
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { | ||
convertAVFrameToFrameOutputOnCPU( | ||
avFrame, frameOutput, preAllocatedOutputTensor); | ||
|
@@ -1345,24 +1353,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( | |
|
||
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( | ||
UniqueAVFrame& srcAVFrame, | ||
FrameOutput& frameOutput, | ||
std::optional<torch::Tensor> preAllocatedOutputTensor) { | ||
TORCH_CHECK( | ||
!preAllocatedOutputTensor.has_value(), | ||
"pre-allocated audio tensor not supported yet."); | ||
|
||
FrameOutput& frameOutput) { | ||
AVSampleFormat sourceSampleFormat = | ||
static_cast<AVSampleFormat>(srcAVFrame->format); | ||
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP; | ||
|
||
int sourceSampleRate = srcAVFrame->sample_rate; | ||
int desiredSampleRate = | ||
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or( | ||
sourceSampleRate); | ||
|
||
bool mustConvert = | ||
(sourceSampleFormat != desiredSampleFormat || | ||
sourceSampleRate != desiredSampleRate); | ||
|
||
UniqueAVFrame convertedAVFrame; | ||
if (sourceSampleFormat != desiredSampleFormat) { | ||
convertedAVFrame = convertAudioAVFrameSampleFormat( | ||
srcAVFrame, sourceSampleFormat, desiredSampleFormat); | ||
if (mustConvert) { | ||
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( | ||
srcAVFrame, | ||
sourceSampleFormat, | ||
desiredSampleFormat, | ||
sourceSampleRate, | ||
desiredSampleRate); | ||
} | ||
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat) | ||
? convertedAVFrame | ||
: srcAVFrame; | ||
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; | ||
|
||
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); | ||
TORCH_CHECK( | ||
|
@@ -1385,55 +1399,107 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( | |
memcpy( | ||
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel); | ||
} | ||
|
||
frameOutput.data = outputData; | ||
} | ||
|
||
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat( | ||
const UniqueAVFrame& avFrame, | ||
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate( | ||
const UniqueAVFrame& srcAVFrame, | ||
AVSampleFormat sourceSampleFormat, | ||
AVSampleFormat desiredSampleFormat | ||
|
||
) { | ||
AVSampleFormat desiredSampleFormat, | ||
int sourceSampleRate, | ||
int desiredSampleRate) { | ||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
const auto& streamMetadata = | ||
containerMetadata_.allStreamMetadata[activeStreamIndex_]; | ||
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value()); | ||
|
||
if (!streamInfo.swrContext) { | ||
createSwrContext( | ||
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat); | ||
streamInfo, | ||
sourceSampleFormat, | ||
desiredSampleFormat, | ||
sourceSampleRate, | ||
desiredSampleRate); | ||
} | ||
|
||
UniqueAVFrame convertedAVFrame(av_frame_alloc()); | ||
TORCH_CHECK( | ||
convertedAVFrame, | ||
"Could not allocate frame for sample format conversion."); | ||
|
||
setChannelLayout(convertedAVFrame, avFrame); | ||
setChannelLayout(convertedAVFrame, srcAVFrame); | ||
convertedAVFrame->format = static_cast<int>(desiredSampleFormat); | ||
convertedAVFrame->sample_rate = avFrame->sample_rate; | ||
convertedAVFrame->nb_samples = avFrame->nb_samples; | ||
convertedAVFrame->sample_rate = desiredSampleRate; | ||
if (sourceSampleRate != desiredSampleRate) { | ||
// Note that this is an upper bound on the number of output samples. | ||
// `swr_convert()` will likely not fill convertedAVFrame with that many | ||
// samples if sample rate conversion is needed. It will buffer the last few | ||
// ones because those require future samples. That's also why we reset | ||
// nb_samples after the call to `swr_convert()`. | ||
convertedAVFrame->nb_samples = av_rescale_rnd( | ||
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) + | ||
srcAVFrame->nb_samples, | ||
desiredSampleRate, | ||
sourceSampleRate, | ||
AV_ROUND_UP); | ||
} else { | ||
convertedAVFrame->nb_samples = srcAVFrame->nb_samples; | ||
} | ||
|
||
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); | ||
TORCH_CHECK( | ||
status == AVSUCCESS, | ||
"Could not allocate frame buffers for sample format conversion: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
auto numSampleConverted = swr_convert( | ||
auto numConvertedSamples = swr_convert( | ||
streamInfo.swrContext.get(), | ||
convertedAVFrame->data, | ||
convertedAVFrame->nb_samples, | ||
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)), | ||
avFrame->nb_samples); | ||
static_cast<const uint8_t**>( | ||
const_cast<const uint8_t**>(srcAVFrame->data)), | ||
srcAVFrame->nb_samples); | ||
TORCH_CHECK( | ||
numSampleConverted > 0, | ||
numConvertedSamples > 0, | ||
"Error in swr_convert: ", | ||
getFFMPEGErrorStringFromErrorCode(numSampleConverted)); | ||
getFFMPEGErrorStringFromErrorCode(numConvertedSamples)); | ||
|
||
// See comment above about nb_samples | ||
convertedAVFrame->nb_samples = numConvertedSamples; | ||
|
||
return convertedAVFrame; | ||
} | ||
|
||
torch::Tensor VideoDecoder::maybeFlushSwrBuffers() { | ||
// When sample rate conversion is involved, swresample buffers some of the | ||
// samples in-between calls to swr_convert (see the libswresample docs). | ||
// That's because the last few samples in a given frame require future samples | ||
// from the next frame to be properly converted. This function flushes out the | ||
// samples that are stored in swresample's buffers. | ||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
if (!streamInfo.swrContext) { | ||
return torch::empty({0, 0}); | ||
} | ||
auto numRemainingSamples = // this is an upper bound | ||
swr_get_out_samples(streamInfo.swrContext.get(), 0); | ||
|
||
if (numRemainingSamples == 0) { | ||
return torch::empty({0, 0}); | ||
} | ||
|
||
torch::Tensor lastSamples = torch::empty( | ||
{getNumChannels(streamInfo.codecContext), numRemainingSamples}, | ||
torch::kFloat32); | ||
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr()); | ||
|
||
auto actualNumRemainingSamples = swr_convert( | ||
streamInfo.swrContext.get(), | ||
&lastSamplesData, | ||
numRemainingSamples, | ||
NULL, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/ |
||
0); | ||
return lastSamples.narrow( | ||
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); | ||
} | ||
|
||
// -------------------------------------------------------------------------- | ||
// OUTPUT ALLOCATION AND SHAPE CONVERSION | ||
// -------------------------------------------------------------------------- | ||
|
@@ -1669,14 +1735,16 @@ void VideoDecoder::createSwsContext( | |
|
||
void VideoDecoder::createSwrContext( | ||
StreamInfo& streamInfo, | ||
int sampleRate, | ||
AVSampleFormat sourceSampleFormat, | ||
AVSampleFormat desiredSampleFormat) { | ||
AVSampleFormat desiredSampleFormat, | ||
int sourceSampleRate, | ||
int desiredSampleRate) { | ||
auto swrContext = allocateSwrContext( | ||
streamInfo.codecContext, | ||
sampleRate, | ||
sourceSampleFormat, | ||
desiredSampleFormat); | ||
desiredSampleFormat, | ||
sourceSampleRate, | ||
desiredSampleRate); | ||
|
||
auto status = swr_init(swrContext); | ||
TORCH_CHECK( | ||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -139,10 +139,18 @@ class VideoDecoder { | |||||||||
torch::Device device = torch::kCPU; | ||||||||||
}; | ||||||||||
|
||||||||||
struct AudioStreamOptions { | ||||||||||
AudioStreamOptions() {} | ||||||||||
|
||||||||||
std::optional<int> sampleRate; | ||||||||||
}; | ||||||||||
|
||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need the indirection of having an options struct. I know it mirrors the pattern established on the video side, but I also don't think it's a good practice there, either. It's harder to get rid of the video options because we accept a string, and then we do a bunch of work parsing the string. Getting rid of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a fan of the stringy video options either - that's why I didn't implement a string constructor for audio options. I also opened #577 to entirely remove string video options (it's not as hard as we thought). I don't feel very strongly about this, but if we were to collapse both video options and audio options into the video decoder (or the StreamInfo), then we would have a lot of video-only and audio-only fields within the same struct. I personally find it cleaner to separate those into separate structs. Additionally, using option structs makes it very clear which fields/values come from user-specified parameter, in constrast to e.g. metadata or video properties, which is often useful to immediately understand the source of the values, as e.g. here: torchcodec/src/torchcodec/decoders/_core/VideoDecoder.cpp Lines 1395 to 1398 in 7cb2271
LMK your thoughts, I'm fine with collapsing sampleRate as a StreamInfo field if you prefer, but we'd potentially be losing the 2 benefits mentioned above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that's a good point, we actually store the options. I was thinking just from the |
||||||||||
void addVideoStream( | ||||||||||
int streamIndex, | ||||||||||
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions()); | ||||||||||
void addAudioStream(int streamIndex); | ||||||||||
void addAudioStream( | ||||||||||
int streamIndex, | ||||||||||
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); | ||||||||||
|
||||||||||
// -------------------------------------------------------------------------- | ||||||||||
// DECODING AND SEEKING APIs | ||||||||||
|
@@ -336,6 +344,7 @@ class VideoDecoder { | |||||||||
int64_t lastDecodedAvFramePts = 0; | ||||||||||
int64_t lastDecodedAvFrameDuration = 0; | ||||||||||
VideoStreamOptions videoStreamOptions; | ||||||||||
AudioStreamOptions audioStreamOptions; | ||||||||||
|
||||||||||
// color-conversion fields. Only one of FilterGraphContext and | ||||||||||
// UniqueSwsContext should be non-null. | ||||||||||
|
@@ -382,8 +391,7 @@ class VideoDecoder { | |||||||||
|
||||||||||
void convertAudioAVFrameToFrameOutputOnCPU( | ||||||||||
UniqueAVFrame& srcAVFrame, | ||||||||||
FrameOutput& frameOutput, | ||||||||||
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); | ||||||||||
FrameOutput& frameOutput); | ||||||||||
|
||||||||||
torch::Tensor convertAVFrameToTensorUsingFilterGraph( | ||||||||||
const UniqueAVFrame& avFrame); | ||||||||||
|
@@ -392,10 +400,14 @@ class VideoDecoder { | |||||||||
const UniqueAVFrame& avFrame, | ||||||||||
torch::Tensor& outputTensor); | ||||||||||
|
||||||||||
UniqueAVFrame convertAudioAVFrameSampleFormat( | ||||||||||
const UniqueAVFrame& avFrame, | ||||||||||
UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( | ||||||||||
const UniqueAVFrame& srcAVFrame, | ||||||||||
AVSampleFormat sourceSampleFormat, | ||||||||||
AVSampleFormat desiredSampleFormat); | ||||||||||
AVSampleFormat desiredSampleFormat, | ||||||||||
int sourceSampleRate, | ||||||||||
int desiredSampleRate); | ||||||||||
|
||||||||||
torch::Tensor maybeFlushSwrBuffers(); | ||||||||||
|
||||||||||
// -------------------------------------------------------------------------- | ||||||||||
// COLOR CONVERSION LIBRARIES HANDLERS CREATION | ||||||||||
|
@@ -413,9 +425,10 @@ class VideoDecoder { | |||||||||
|
||||||||||
void createSwrContext( | ||||||||||
StreamInfo& streamInfo, | ||||||||||
int sampleRate, | ||||||||||
AVSampleFormat sourceSampleFormat, | ||||||||||
AVSampleFormat desiredSampleFormat); | ||||||||||
AVSampleFormat desiredSampleFormat, | ||||||||||
int sourceSampleRate, | ||||||||||
int desiredSampleRate); | ||||||||||
|
||||||||||
// -------------------------------------------------------------------------- | ||||||||||
// PTS <-> INDEX CONVERSIONS | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also consider exposing this as a
desired_sample_rate
parameter? I don't have a strong opinion; the docs would make it clear what this means in any case.