Skip to content

Commit

Permalink
FFT: new stride logic
Browse files Browse the repository at this point in the history
  • Loading branch information
pschatzmann committed Feb 20, 2025
1 parent d9b9399 commit d44d0f5
Showing 1 changed file with 71 additions and 60 deletions.
131 changes: 71 additions & 60 deletions src/AudioTools/AudioLibs/AudioFFT.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ struct AudioFFTConfig : public AudioInfo {
uint8_t channel_used = 0;
int length = 8192;
int stride = 0;
/// Optional window function
/// Optional window function for both fft and ifft
WindowFunction *window_function = nullptr;
/// Optional window function for fft only
WindowFunction *window_function_fft = nullptr;
/// Optional window function for ifft only
WindowFunction *window_function_ifft = nullptr;
/// TX_MODE = FFT, RX_MODE = IFFT
RxTxMode rxtx_mode = TX_MODE;
};
Expand Down Expand Up @@ -96,18 +100,19 @@ class FFTInverseOverlapAdder {
// adds the values to the array (by applying the window function)
void add(float value, int pos, WindowFunction *window_function) {
float add_value = value;
if (window_function != nullptr)
if (window_function != nullptr) {
add_value = value * window_function->factor(pos);
}
data[pos] += add_value;
}

// gets the scaled audio data as result
void getStepData(float *result, int step_size, float maxResult) {
for (int j = 0; j < step_size; j++) {
void getStepData(float *result, int stride, float maxResult) {
for (int j = 0; j < stride; j++) {
// determine max value to scale
if (rfft_max < data[j]) rfft_max = data[j];
if (data[j] > rfft_max) rfft_max = data[j];
}
for (int j = 0; j < step_size; j++) {
for (int j = 0; j < stride; j++) {
result[j] = data[j] / rfft_max * maxResult;
// clip
if (result[j] > maxResult) {
Expand All @@ -118,17 +123,21 @@ class FFTInverseOverlapAdder {
}
}
// copy data to head
for (int j = 0; j < len - step_size; j++) {
data[j] = data[j + step_size];
for (int j = 0; j < len - stride; j++) {
data[j] = data[j + stride];
}
// clear tail
for (int j = len - step_size; j < len; j++) {
for (int j = len - stride; j < len; j++) {
data[j] = 0.0;
}
}

int size() { return data.size(); }

void begin(){
rfft_max = 0;
}

protected:
Vector<float> data{0};
int len = 0;
Expand Down Expand Up @@ -200,24 +209,32 @@ class AudioFFTBase : public AudioStream {
/// starts the processing
bool begin() override {
bins = cfg.length / 2;
// define window functions
if (cfg.window_function_fft==nullptr) cfg.window_function_fft = cfg.window_function;
if (cfg.window_function_ifft==nullptr) cfg.window_function_ifft = cfg.window_function;
// define default stride value if not defined
if (cfg.stride == 0) cfg.stride = cfg.length;

if (!isPowerOfTwo(cfg.length)) {
LOGE("Len must be of the power of 2: %d", cfg.length);
return false;
}
if (!p_driver->begin(cfg.length)) {
LOGE("Not enough memory");
}
int step_size = cfg.stride > 0 ? cfg.stride : cfg.length;
if (cfg.window_function != nullptr) {
cfg.window_function->begin(step_size);

if (cfg.window_function_fft != nullptr) {
cfg.window_function_fft->begin(cfg.length);
}
if (cfg.window_function_ifft != nullptr) {
cfg.window_function_ifft->begin(cfg.length);
}

int step_size = cfg.stride > 0 ? cfg.stride : cfg.length;
bool is_valid_rxtx = false;
if (cfg.rxtx_mode == TX_MODE || cfg.rxtx_mode == RXTX_MODE) {
if (cfg.stride > 0 && cfg.stride < cfg.length) {
// holds last N bytes that need to be reprocessed
stride_buffer.resize((cfg.length - cfg.stride) * bytesPerSample());
}
// holds last N bytes that need to be reprocessed
stride_buffer.resize((cfg.length) * bytesPerSample());
is_valid_rxtx = true;
}
if (cfg.rxtx_mode == RX_MODE || cfg.rxtx_mode == RXTX_MODE) {
Expand All @@ -238,8 +255,11 @@ class AudioFFTBase : public AudioStream {
/// Just resets the current_pos e.g. to start a new cycle
void reset() {
current_pos = 0;
if (cfg.window_function != nullptr) {
cfg.window_function->begin(length());
if (cfg.window_function_fft != nullptr) {
cfg.window_function_fft->begin(cfg.length);
}
if (cfg.window_function_ifft != nullptr) {
cfg.window_function_ifft->begin(cfg.length);
}
}

Expand Down Expand Up @@ -290,7 +310,7 @@ class AudioFFTBase : public AudioStream {

/// Provides the result of a reverse FFT
size_t readBytes(uint8_t *data, size_t len) override {
LOGD("setup ifft data");
TRACED();
if (rfft_data.size() == 0) return 0;
// execute rfft when we consumed all data
if (has_rfft_data && rfft_data.available() == 0) {
Expand Down Expand Up @@ -432,7 +452,7 @@ class AudioFFTBase : public AudioStream {
AudioFFTConfig cfg;
unsigned long timestamp_begin = 0l;
unsigned long timestamp = 0l;
RingBuffer<uint8_t> stride_buffer{0};
SingleBuffer<uint8_t> stride_buffer{0};
Vector<float> l_magnitudes{0};
Vector<float> step_data{0};
int bins = 0;
Expand All @@ -445,35 +465,35 @@ class AudioFFTBase : public AudioStream {
void processSamples(const void *data, size_t samples) {
T *dataT = (T *)data;
T sample;
float sample_windowed;
for (int j = 0; j < samples; j += cfg.channels) {
sample = dataT[j + cfg.channel_used];
p_driver->setValue(current_pos, windowedSample(sample, current_pos));
writeStrideBuffer((uint8_t *)&sample, sizeof(T));
if (++current_pos >= cfg.length) {
fft<T>();
current_pos = 0;

// reprocess data in stride buffer
if (stride_buffer.size() > 0) {
// reload data from stride buffer
while (stride_buffer.available()) {
T sample;
stride_buffer.readArray((uint8_t *)&sample, sizeof(T));
p_driver->setValue(current_pos,
windowedSample(sample, current_pos));
current_pos++;
}
if (writeStrideBuffer((uint8_t *)&sample, sizeof(T))){
// process data if buffer is full
T* samples = (T*) stride_buffer.data();
int sample_count = stride_buffer.size() / sizeof(T);
assert(sample_count == cfg.length);
for (int j=0; j< sample_count; j++){
T out_sample = samples[j];
p_driver->setValue(j, windowedSample(out_sample, j));
}

fft<T>();

// remove stride samples
stride_buffer.clearArray(cfg.stride * sizeof(T));

// validate available data in stride buffer
if (cfg.stride == cfg.length) assert(stride_buffer.available()==0);

}
}
}

template <typename T>
T windowedSample(T sample, int pos) {
T result = sample;
if (cfg.window_function != nullptr) {
result = cfg.window_function->factor(pos) * sample;
if (cfg.window_function_fft != nullptr) {
result = cfg.window_function_fft->factor(pos) * sample;
}
return result;
}
Expand All @@ -498,35 +518,34 @@ class AudioFFTBase : public AudioStream {
// add data to sum buffer
for (int j = 0; j < cfg.length; j++) {
float value = p_driver->getValue(j);
rfft_add.add(value, j, cfg.window_function);
rfft_add.add(value, j, cfg.window_function_ifft);
}
// get result data from sum buffer
rfftWriteData(rfft_data);
}

/// write reverse fft result to buffer to make it available for readBytes
void rfftWriteData(BaseBuffer<uint8_t> &data) {
int step_size = cfg.stride > 0 ? cfg.stride : cfg.length;
// get data to result buffer
step_data.resize(step_size);
for (int j = 0; j < step_size; j++) {
step_data.resize(cfg.stride);
for (int j = 0; j < cfg.stride; j++) {
step_data[j] = 0.0;
}
rfft_add.getStepData(step_data.data(), step_size,
rfft_add.getStepData(step_data.data(), cfg.stride,
NumberConverter::maxValue(cfg.bits_per_sample));

switch (cfg.bits_per_sample) {
case 8:
writeIFFT<int8_t>(step_data.data(), step_size);
writeIFFT<int8_t>(step_data.data(), cfg.stride);
break;
case 16:
writeIFFT<int16_t>(step_data.data(), step_size);
writeIFFT<int16_t>(step_data.data(), cfg.stride);
break;
case 24:
writeIFFT<int24_t>(step_data.data(), step_size);
writeIFFT<int24_t>(step_data.data(), cfg.stride);
break;
case 32:
writeIFFT<int32_t>(step_data.data(), step_size);
writeIFFT<int32_t>(step_data.data(), cfg.stride);
break;
default:
LOGE("Unsupported bits: %d", cfg.bits_per_sample);
Expand Down Expand Up @@ -567,18 +586,10 @@ class AudioFFTBase : public AudioStream {
}
}

void writeStrideBuffer(uint8_t *buffer, size_t len) {
if (stride_buffer.size() > 0) {
int available = stride_buffer.availableForWrite();
if (len > available) {
// clear oldest values to make space
int diff = len - available;
for (int j = 0; j < diff; j++) {
stride_buffer.read();
}
}
stride_buffer.writeArray(buffer, len);
}
// adds samples to stride buffer, returns true if the buffer is full
bool writeStrideBuffer(uint8_t *buffer, size_t len) {
stride_buffer.writeArray(buffer, len);
return stride_buffer.isFull();
}

bool isPowerOfTwo(uint16_t x) { return (x & (x - 1)) == 0; }
Expand Down

0 comments on commit d44d0f5

Please sign in to comment.