-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow `sample_rate` parameter to audio decoder
#551
Changes from all commits
f6a7f4e
179a01c
2d97555
9af4bc8
2adf496
ef93be4
db740a6
ca15232
6aa7b09
f858d0c
70ac31e
8deb079
7b09315
af4e88a
975b0fb
7cb2271
ee1c7b7
bf9aed2
f0e2cdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -580,14 +580,18 @@ void VideoDecoder::addVideoStream( | |
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); | ||
} | ||
|
||
void VideoDecoder::addAudioStream(int streamIndex) { | ||
void VideoDecoder::addAudioStream( | ||
int streamIndex, | ||
const AudioStreamOptions& audioStreamOptions) { | ||
TORCH_CHECK( | ||
seekMode_ == SeekMode::approximate, | ||
"seek_mode must be 'approximate' for audio streams."); | ||
|
||
addStream(streamIndex, AVMEDIA_TYPE_AUDIO); | ||
|
||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
streamInfo.audioStreamOptions = audioStreamOptions; | ||
|
||
auto& streamMetadata = | ||
containerMetadata_.allStreamMetadata[activeStreamIndex_]; | ||
streamMetadata.sampleRate = | ||
|
@@ -947,6 +951,11 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( | |
(stopPts <= lastDecodedAvFrameEnd); | ||
} | ||
|
||
auto lastSamples = maybeFlushSwrBuffers(); | ||
if (lastSamples.has_value()) { | ||
frames.push_back(*lastSamples); | ||
} | ||
Reviewer comment: Not particularly fond of the above. Maybe we could let [comment truncated in capture]. | ||
Reviewer reply: Everything I can think of that does things unconditionally in this function is way too cute (returning a vector of tensors; using [comment truncated in capture]). | ||
Author reply: Using optional sounds better, thanks. |
||
|
||
return AudioFramesOutput{torch::cat(frames, 1), firstFramePtsSeconds}; | ||
} | ||
|
||
|
@@ -1200,8 +1209,7 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput( | |
getDuration(avFrame), | ||
formatContext_->streams[activeStreamIndex_]->time_base); | ||
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { | ||
convertAudioAVFrameToFrameOutputOnCPU( | ||
avFrame, frameOutput, preAllocatedOutputTensor); | ||
convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); | ||
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) { | ||
convertAVFrameToFrameOutputOnCPU( | ||
avFrame, frameOutput, preAllocatedOutputTensor); | ||
|
@@ -1379,24 +1387,30 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( | |
|
||
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( | ||
UniqueAVFrame& srcAVFrame, | ||
FrameOutput& frameOutput, | ||
std::optional<torch::Tensor> preAllocatedOutputTensor) { | ||
TORCH_CHECK( | ||
!preAllocatedOutputTensor.has_value(), | ||
"pre-allocated audio tensor not supported yet."); | ||
|
||
FrameOutput& frameOutput) { | ||
AVSampleFormat sourceSampleFormat = | ||
static_cast<AVSampleFormat>(srcAVFrame->format); | ||
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP; | ||
|
||
int sourceSampleRate = srcAVFrame->sample_rate; | ||
int desiredSampleRate = | ||
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or( | ||
sourceSampleRate); | ||
|
||
bool mustConvert = | ||
(sourceSampleFormat != desiredSampleFormat || | ||
sourceSampleRate != desiredSampleRate); | ||
|
||
UniqueAVFrame convertedAVFrame; | ||
if (sourceSampleFormat != desiredSampleFormat) { | ||
convertedAVFrame = convertAudioAVFrameSampleFormat( | ||
srcAVFrame, sourceSampleFormat, desiredSampleFormat); | ||
if (mustConvert) { | ||
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( | ||
srcAVFrame, | ||
sourceSampleFormat, | ||
desiredSampleFormat, | ||
sourceSampleRate, | ||
desiredSampleRate); | ||
} | ||
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat) | ||
? convertedAVFrame | ||
: srcAVFrame; | ||
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; | ||
|
||
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format); | ||
TORCH_CHECK( | ||
|
@@ -1419,55 +1433,110 @@ void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU( | |
memcpy( | ||
outputChannelData, avFrame->extended_data[channel], numBytesPerChannel); | ||
} | ||
|
||
frameOutput.data = outputData; | ||
} | ||
|
||
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormat( | ||
const UniqueAVFrame& avFrame, | ||
UniqueAVFrame VideoDecoder::convertAudioAVFrameSampleFormatAndSampleRate( | ||
const UniqueAVFrame& srcAVFrame, | ||
AVSampleFormat sourceSampleFormat, | ||
AVSampleFormat desiredSampleFormat | ||
|
||
) { | ||
AVSampleFormat desiredSampleFormat, | ||
int sourceSampleRate, | ||
int desiredSampleRate) { | ||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
const auto& streamMetadata = | ||
containerMetadata_.allStreamMetadata[activeStreamIndex_]; | ||
int sampleRate = static_cast<int>(streamMetadata.sampleRate.value()); | ||
|
||
if (!streamInfo.swrContext) { | ||
createSwrContext( | ||
streamInfo, sampleRate, sourceSampleFormat, desiredSampleFormat); | ||
streamInfo, | ||
sourceSampleFormat, | ||
desiredSampleFormat, | ||
sourceSampleRate, | ||
desiredSampleRate); | ||
} | ||
|
||
UniqueAVFrame convertedAVFrame(av_frame_alloc()); | ||
TORCH_CHECK( | ||
convertedAVFrame, | ||
"Could not allocate frame for sample format conversion."); | ||
|
||
setChannelLayout(convertedAVFrame, avFrame); | ||
setChannelLayout(convertedAVFrame, srcAVFrame); | ||
convertedAVFrame->format = static_cast<int>(desiredSampleFormat); | ||
convertedAVFrame->sample_rate = avFrame->sample_rate; | ||
convertedAVFrame->nb_samples = avFrame->nb_samples; | ||
convertedAVFrame->sample_rate = desiredSampleRate; | ||
if (sourceSampleRate != desiredSampleRate) { | ||
// Note that this is an upper bound on the number of output samples. | ||
// `swr_convert()` will likely not fill convertedAVFrame with that many | ||
// samples if sample rate conversion is needed. It will buffer the last few | ||
// ones because those require future samples. That's also why we reset | ||
// nb_samples after the call to `swr_convert()`. | ||
// We could also use `swr_get_out_samples()` to determine the number of | ||
// output samples, but empirically `av_rescale_rnd()` seems to provide a | ||
// tighter bound. | ||
convertedAVFrame->nb_samples = av_rescale_rnd( | ||
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) + | ||
srcAVFrame->nb_samples, | ||
desiredSampleRate, | ||
sourceSampleRate, | ||
AV_ROUND_UP); | ||
} else { | ||
convertedAVFrame->nb_samples = srcAVFrame->nb_samples; | ||
} | ||
|
||
auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); | ||
TORCH_CHECK( | ||
status == AVSUCCESS, | ||
"Could not allocate frame buffers for sample format conversion: ", | ||
getFFMPEGErrorStringFromErrorCode(status)); | ||
|
||
auto numSampleConverted = swr_convert( | ||
auto numConvertedSamples = swr_convert( | ||
streamInfo.swrContext.get(), | ||
convertedAVFrame->data, | ||
convertedAVFrame->nb_samples, | ||
static_cast<const uint8_t**>(const_cast<const uint8_t**>(avFrame->data)), | ||
avFrame->nb_samples); | ||
static_cast<const uint8_t**>( | ||
const_cast<const uint8_t**>(srcAVFrame->data)), | ||
srcAVFrame->nb_samples); | ||
TORCH_CHECK( | ||
numSampleConverted > 0, | ||
numConvertedSamples > 0, | ||
"Error in swr_convert: ", | ||
getFFMPEGErrorStringFromErrorCode(numSampleConverted)); | ||
getFFMPEGErrorStringFromErrorCode(numConvertedSamples)); | ||
|
||
// See comment above about nb_samples | ||
convertedAVFrame->nb_samples = numConvertedSamples; | ||
|
||
return convertedAVFrame; | ||
} | ||
|
||
std::optional<torch::Tensor> VideoDecoder::maybeFlushSwrBuffers() { | ||
// When sample rate conversion is involved, swresample buffers some of the | ||
// samples in-between calls to swr_convert (see the libswresample docs). | ||
// That's because the last few samples in a given frame require future samples | ||
// from the next frame to be properly converted. This function flushes out the | ||
// samples that are stored in swresample's buffers. | ||
auto& streamInfo = streamInfos_[activeStreamIndex_]; | ||
if (!streamInfo.swrContext) { | ||
return std::nullopt; | ||
} | ||
auto numRemainingSamples = // this is an upper bound | ||
swr_get_out_samples(streamInfo.swrContext.get(), 0); | ||
|
||
if (numRemainingSamples == 0) { | ||
return std::nullopt; | ||
} | ||
|
||
torch::Tensor lastSamples = torch::empty( | ||
{getNumChannels(streamInfo.codecContext), numRemainingSamples}, | ||
torch::kFloat32); | ||
uint8_t* lastSamplesData = static_cast<uint8_t*>(lastSamples.data_ptr()); | ||
|
||
auto actualNumRemainingSamples = swr_convert( | ||
streamInfo.swrContext.get(), | ||
&lastSamplesData, | ||
numRemainingSamples, | ||
nullptr, | ||
0); | ||
return lastSamples.narrow( | ||
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); | ||
} | ||
|
||
// -------------------------------------------------------------------------- | ||
// OUTPUT ALLOCATION AND SHAPE CONVERSION | ||
// -------------------------------------------------------------------------- | ||
|
@@ -1703,14 +1772,16 @@ void VideoDecoder::createSwsContext( | |
|
||
void VideoDecoder::createSwrContext( | ||
StreamInfo& streamInfo, | ||
int sampleRate, | ||
AVSampleFormat sourceSampleFormat, | ||
AVSampleFormat desiredSampleFormat) { | ||
AVSampleFormat desiredSampleFormat, | ||
int sourceSampleRate, | ||
int desiredSampleRate) { | ||
auto swrContext = allocateSwrContext( | ||
streamInfo.codecContext, | ||
sampleRate, | ||
sourceSampleFormat, | ||
desiredSampleFormat); | ||
desiredSampleFormat, | ||
sourceSampleRate, | ||
desiredSampleRate); | ||
|
||
auto status = swr_init(swrContext); | ||
TORCH_CHECK( | ||
|
Reviewer comment: We could also consider exposing this as a `desired_sample_rate` parameter? I don't have a strong opinion; the docs would make it clear what this means in any case.