diff --git a/AecSample.cpp b/AecSample.cpp index c99fd4b..b7a84b9 100644 --- a/AecSample.cpp +++ b/AecSample.cpp @@ -1,72 +1,79 @@ - - +#include #include -#include "DTLN_AEC.h" - -int main(int argc, char *argv[]) -{ - //lpszInputRefWave is a reference data(far end) - //lpszInputRecWave is a recording data(near end recording) - - std::string lpszInputRefWave = std::string(argv[1]); - std::string lpszInputRecWave = std::string(argv[2]); - std::string lpszOutputWave = std::string(argv[3]); - - FILE *lpoInputRefFile = NULL; - FILE *lpoInputRecFile = NULL; - FILE *lpoOutputFile = NULL; - - short *lpsInputRefSample = NULL; - short *lpsInputRecSample = NULL; - short *lpsOutputSample = NULL; - - int nReadSize, nFrameSize; - - lpoInputRefFile = fopen(lpszInputRefWave.c_str(), "rb"); - lpoInputRecFile = fopen(lpszInputRecWave.c_str(), "rb"); - lpoOutputFile = fopen(lpszOutputWave.c_str(), "wb+"); - - DTLN_AEC oDtlnAec; +#include - nFrameSize = oDtlnAec.Init(); - - lpsInputRefSample = new short[nFrameSize]; - lpsInputRecSample = new short[nFrameSize]; - lpsOutputSample = new short[nFrameSize]; - - //Skip wave header - fread(lpsInputRefSample, 1, 44, lpoInputRefFile); - fread(lpsInputRecSample, 1, 44, lpoInputRecFile); - - while (true) - { - nReadSize = fread(lpsInputRefSample, 1, nFrameSize * sizeof(short), lpoInputRefFile); - if (nReadSize <= 0) - break; - - nReadSize = fread(lpsInputRecSample, 1, nFrameSize * sizeof(short), lpoInputRecFile); - if (nReadSize <= 0) - break; - - oDtlnAec.Process(lpsInputRefSample, lpsInputRecSample, lpsOutputSample); +#include "DTLN_AEC.h" - //write PCM - fwrite(lpsOutputSample, 1, nFrameSize * sizeof(short), lpoOutputFile); +namespace { +constexpr int kWaveHeaderSize = 44; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::fprintf(stderr, + "Usage: %s \n", + argv[0]); + return 1; + } + + const std::string input_ref_wave(argv[1]); + const std::string input_rec_wave(argv[2]); + const std::string output_wave(argv[3]); + + FILE* input_ref_file = std::fopen(input_ref_wave.c_str(), "rb"); + FILE* input_rec_file = std::fopen(input_rec_wave.c_str(), "rb"); + FILE* output_file = std::fopen(output_wave.c_str(), "wb+"); + + if (input_ref_file == nullptr || input_rec_file == nullptr || + output_file == nullptr) { + std::fprintf(stderr, "Failed to open one or more files.\n"); + if (input_ref_file != nullptr) std::fclose(input_ref_file); + if (input_rec_file != nullptr) std::fclose(input_rec_file); + if (output_file != nullptr) std::fclose(output_file); + return 1; + } + + DTLN_AEC dtln_aec; + const int frame_size = dtln_aec.Init(); + if (frame_size <= 0) { + std::fprintf(stderr, "DTLN_AEC initialization failed.\n"); + std::fclose(input_ref_file); + std::fclose(input_rec_file); + std::fclose(output_file); + return 1; + } + + std::vector input_ref_sample(frame_size); + std::vector input_rec_sample(frame_size); + std::vector output_sample(frame_size); + + // Skip wave header. + std::fread(input_ref_sample.data(), 1, kWaveHeaderSize, input_ref_file); + std::fread(input_rec_sample.data(), 1, kWaveHeaderSize, input_rec_file); + + while (true) { + int read_size = std::fread(input_ref_sample.data(), 1, + frame_size * sizeof(short), input_ref_file); + if (read_size <= 0) { + break; } - fclose(lpoInputRefFile); - fclose(lpoInputRecFile); - fclose(lpoOutputFile); - - if (lpsInputRefSample != NULL) - delete[] lpsInputRefSample; + read_size = std::fread(input_rec_sample.data(), 1, + frame_size * sizeof(short), input_rec_file); + if (read_size <= 0) { + break; + } - if (lpsInputRecSample != NULL) - delete[] lpsInputRecSample; + dtln_aec.Process(input_ref_sample.data(), input_rec_sample.data(), + output_sample.data()); - if (lpsOutputSample != NULL) - delete[] lpsOutputSample; + std::fwrite(output_sample.data(), 1, frame_size * sizeof(short), + output_file); + } - return 0; + std::fclose(input_ref_file); + std::fclose(input_rec_file); + std::fclose(output_file); -} \ No newline at end of file + return 0; +} diff --git a/DTLN_AEC/DTLN_AEC.cpp b/DTLN_AEC/DTLN_AEC.cpp index f55f00a..d7dc742 100644 --- a/DTLN_AEC/DTLN_AEC.cpp +++ b/DTLN_AEC/DTLN_AEC.cpp @@ -1,442 +1,423 @@ - -#include + +// NOLINTBEGIN #include +#include + +#include +#include +#include #include "dtln_aec_128_1.h" #include "dtln_aec_128_2.h" -//Use KissFFT https://github.com/mborgerding/kissfft +// Use KissFFT https://github.com/mborgerding/kissfft +#include "DTLN_AEC.h" #include "kiss_fftr.h" +// NOLINTEND -#include "DTLN_AEC.h" +// 1 Network contain 2 models +// Please check : https://github.com/breizhn/DTLN-aec +// This code is translate from : +// https://github.com/breizhn/DTLN-aec/blob/main/run_aec.py -//1 Network contain 2 models -//Please check : https://github.com/breizhn/DTLN-aec -//This code is translate from : https://github.com/breizhn/DTLN-aec/blob/main/run_aec.py +// const param +constexpr int kWindowSize = 512; +constexpr int kWindowShift = 128; +constexpr int kFftForTensorSize = (kWindowSize / 2 + 1); -//const param -constexpr auto k_nWindowSize = 512; -constexpr auto k_nWindowShift = 128; -constexpr auto k_nFftForTensorSize = (k_nWindowSize / 2 + 1); +constexpr int kNumModels = 2; -constexpr auto k_nNumModels = 2; +constexpr int kNumThreads = 1; -constexpr auto k_nNumThreads = 1; +class DTLN_AEC::Impl { + public: + int Init(); + void Release(); -class DTLN_AEC::m_Impl -{ -public: - - int Init(void); - void Release(void); + int Process(short *ref_buffer, short *rec_buffer, short *output_buffer); + void AEC(); - int Process(short *lpsRefBuffer, short *lpsRecBuffer, short *lpsOutputBuffer); - void AEC(void); + TfLiteModel *tflite_models_[kNumModels]; + TfLiteInterpreter *interpreters_[kNumModels]; + TfLiteInterpreterOptions *interpreter_options_ = nullptr; - TfLiteModel *m_lppoTfliteModel[k_nNumModels]; - TfLiteInterpreter *m_lppoInterpreter[k_nNumModels]; - TfLiteInterpreterOptions *m_lpoInterpreterOptions = nullptr; + TfLiteTensor *input_tensors_[kNumModels][3]; + const TfLiteTensor *output_tensors_[kNumModels][2]; - TfLiteTensor *m_lppoInputTensor[k_nNumModels][3]; - const TfLiteTensor *m_lppoOutputTensor[k_nNumModels][2]; + // FFT + kiss_fftr_cfg fftr_cfg_ = nullptr; + kiss_fftr_cfg ifftr_cfg_ = nullptr; - //FFT - kiss_fftr_cfg m_lpoFftrCfg = nullptr; - kiss_fftr_cfg m_lpoIfftrCfg = nullptr; + kiss_fft_cpx *input_ref_cpx_ = nullptr; + kiss_fft_cpx *input_rec_cpx_ = nullptr; + kiss_fft_cpx *output_cpx_ = nullptr; - kiss_fft_cpx *m_lpoInputRefCpx = nullptr; - kiss_fft_cpx *m_lpoInputRecCpx = nullptr; - kiss_fft_cpx *m_lpoOutputCpx = nullptr; + // Internal buffer + float *input_ref_buffer_ = nullptr; + float *input_rec_buffer_ = nullptr; + float *output_buffer_ = nullptr; - //Internal buffer - float *m_lpfInputRefBuffer = nullptr; - float *m_lpfInputRecBuffer = nullptr; - float *m_lpfOutputBuffer = nullptr; + float *dtln_freq_output_ = nullptr; + float *dtln_time_output_ = nullptr; - float *m_lpfDtlnFreqOutput = nullptr; - float *m_lpfDtlnTimeOutput = nullptr; + int state_size_[kNumModels]; + float *states_[kNumModels]; - int m_lpnStateSize[k_nNumModels]; - float *m_lppfStates[k_nNumModels]; + float *input_ref_mag_ = nullptr; + float *input_ref_phase_ = nullptr; - float *m_lpfInputRefMag = nullptr; - float *m_lpfInputRefPhase = nullptr; - - float *m_lpfInputRecMag = nullptr; - float *m_lpfInputRecPhase = nullptr; + float *input_rec_mag_ = nullptr; + float *input_rec_phase_ = nullptr; - float *m_lpfEstimatedBlock = nullptr; + float *estimated_block_ = nullptr; - //Format change buffer - float *m_lpfInputRefSample = nullptr; - float *m_lpfInputRecSample = nullptr; - float *m_lpfOutputSample = nullptr; + // Format change buffer + float *input_ref_sample_ = nullptr; + float *input_rec_sample_ = nullptr; + float *output_sample_ = nullptr; - bool m_bInitSuccess = false; - + bool init_success_ = false; }; -int DTLN_AEC::m_Impl::Init(void) -{ - int nRet = -1; - - do - { - for (int i = 0; i < k_nNumModels; i++) - { - this->m_lppoTfliteModel[i] = nullptr; - this->m_lppoInterpreter[i] = nullptr; - - this->m_lppfStates[i] = nullptr; - } - - //Load models - this->m_lppoTfliteModel[0] = TfLiteModelCreate(k_lpszModel1Tflite, k_nModel1TfliteLen); - this->m_lppoTfliteModel[1] = TfLiteModelCreate(k_lpszModel2Tflite, k_nModel2TfliteLen); +int DTLN_AEC::Impl::Init() { + int ret = -1; - if (this->m_lppoTfliteModel[0] == nullptr || this->m_lppoTfliteModel[1] == nullptr) - break; + do { + for (int i = 0; i < kNumModels; i++) { + tflite_models_[i] = nullptr; + interpreters_[i] = nullptr; - //Create option - this->m_lpoInterpreterOptions = TfLiteInterpreterOptionsCreate(); - TfLiteInterpreterOptionsSetNumThreads(this->m_lpoInterpreterOptions, k_nNumThreads); + states_[i] = nullptr; + } + + // Load models + tflite_models_[0] = + TfLiteModelCreate(k_lpszModel1Tflite, k_nModel1TfliteLen); + tflite_models_[1] = + TfLiteModelCreate(k_lpszModel2Tflite, k_nModel2TfliteLen); - //Create the interpreter - this->m_lppoInterpreter[0] = TfLiteInterpreterCreate(this->m_lppoTfliteModel[0], this->m_lpoInterpreterOptions); - this->m_lppoInterpreter[1] = TfLiteInterpreterCreate(this->m_lppoTfliteModel[1], this->m_lpoInterpreterOptions); + if (tflite_models_[0] == nullptr || tflite_models_[1] == nullptr) break; - if (this->m_lppoInterpreter[0] == nullptr || this->m_lppoInterpreter[1] == nullptr) - break; + // Create option + interpreter_options_ = TfLiteInterpreterOptionsCreate(); + TfLiteInterpreterOptionsSetNumThreads(interpreter_options_, kNumThreads); - //Allocate tensor - if (TfLiteInterpreterAllocateTensors(this->m_lppoInterpreter[0]) != kTfLiteOk) - break; - if (TfLiteInterpreterAllocateTensors(this->m_lppoInterpreter[1]) != kTfLiteOk) - break; + // Create the interpreter + interpreters_[0] = + TfLiteInterpreterCreate(tflite_models_[0], interpreter_options_); + interpreters_[1] = + TfLiteInterpreterCreate(tflite_models_[1], interpreter_options_); - //When use original model - //Input tensor order: - //Model_1[] = {rec, state, ref} - //Model_2[] = {est, state, ref} - //When use quantized models in PiDTLN - //Input tensor order: - //Model_1[] = {rec, ref, state} - //Model_2[] = {ref, state, est} - for (int i = 0; i < k_nNumModels; i++) - { - this->m_lppoInputTensor[i][0] = TfLiteInterpreterGetInputTensor(this->m_lppoInterpreter[i], 0); - this->m_lppoInputTensor[i][1] = TfLiteInterpreterGetInputTensor(this->m_lppoInterpreter[i], 1); - this->m_lppoInputTensor[i][2] = TfLiteInterpreterGetInputTensor(this->m_lppoInterpreter[i], 2); + if (interpreters_[0] == nullptr || interpreters_[1] == nullptr) break; - this->m_lppoOutputTensor[i][0] = TfLiteInterpreterGetOutputTensor(this->m_lppoInterpreter[i], 0); - this->m_lppoOutputTensor[i][1] = TfLiteInterpreterGetOutputTensor(this->m_lppoInterpreter[i], 1); + // Allocate tensor + if (TfLiteInterpreterAllocateTensors(interpreters_[0]) != kTfLiteOk) break; + if (TfLiteInterpreterAllocateTensors(interpreters_[1]) != kTfLiteOk) break; - this->m_lpnStateSize[i] = this->m_lppoInputTensor[i][1]->bytes / sizeof(float); - } + // When use original model + // Input tensor order: + // Model_1[] = {rec, state, ref} + // Model_2[] = {est, state, ref} + // When use quantized models in PiDTLN + // Input tensor order: + // Model_1[] = {rec, ref, state} + // Model_2[] = {ref, state, est} + for (int i = 0; i < kNumModels; i++) { + input_tensors_[i][0] = + TfLiteInterpreterGetInputTensor(interpreters_[i], 0); + input_tensors_[i][1] = + TfLiteInterpreterGetInputTensor(interpreters_[i], 1); + input_tensors_[i][2] = + TfLiteInterpreterGetInputTensor(interpreters_[i], 2); - //RFFT/iRFFT - this->m_lpoFftrCfg = kiss_fftr_alloc(k_nWindowSize, 0, 0, 0); - this->m_lpoIfftrCfg = kiss_fftr_alloc(k_nWindowSize, 1, 0, 0); - - this->m_lpoInputRefCpx = new kiss_fft_cpx[k_nFftForTensorSize]; - this->m_lpoInputRecCpx = new kiss_fft_cpx[k_nFftForTensorSize]; - this->m_lpoOutputCpx = new kiss_fft_cpx[k_nFftForTensorSize]; + output_tensors_[i][0] = + TfLiteInterpreterGetOutputTensor(interpreters_[i], 0); + output_tensors_[i][1] = + TfLiteInterpreterGetOutputTensor(interpreters_[i], 1); - //Internal buffer - this->m_lpfInputRefBuffer = new float[k_nWindowSize]; - this->m_lpfInputRecBuffer = new float[k_nWindowSize]; - this->m_lpfOutputBuffer = new float[k_nWindowSize]; + state_size_[i] = input_tensors_[i][1]->bytes / sizeof(float); + } - memset(this->m_lpfInputRefBuffer, 0, k_nWindowSize * sizeof(float)); - memset(this->m_lpfInputRecBuffer, 0, k_nWindowSize * sizeof(float)); - memset(this->m_lpfOutputBuffer, 0, k_nWindowSize * sizeof(float)); + // RFFT/iRFFT + fftr_cfg_ = kiss_fftr_alloc(kWindowSize, 0, 0, 0); + ifftr_cfg_ = kiss_fftr_alloc(kWindowSize, 1, 0, 0); - this->m_lpfDtlnFreqOutput = new float[k_nFftForTensorSize]; - this->m_lpfDtlnTimeOutput = new float[k_nWindowSize]; + input_ref_cpx_ = new kiss_fft_cpx[kFftForTensorSize]; + input_rec_cpx_ = new kiss_fft_cpx[kFftForTensorSize]; + output_cpx_ = new kiss_fft_cpx[kFftForTensorSize]; - memset(this->m_lpfDtlnFreqOutput, 0, k_nFftForTensorSize * sizeof(float)); - memset(this->m_lpfDtlnTimeOutput, 0, k_nWindowSize * sizeof(float)); + // Internal buffer + input_ref_buffer_ = new float[kWindowSize]; + input_rec_buffer_ = new float[kWindowSize]; + output_buffer_ = new float[kWindowSize]; - this->m_lpfInputRefMag = new float[k_nFftForTensorSize]; - this->m_lpfInputRefPhase = new float[k_nFftForTensorSize]; + memset(input_ref_buffer_, 0, kWindowSize * sizeof(float)); + memset(input_rec_buffer_, 0, kWindowSize * sizeof(float)); + memset(output_buffer_, 0, kWindowSize * sizeof(float)); - memset(this->m_lpfInputRefMag, 0, k_nFftForTensorSize * sizeof(float)); - memset(this->m_lpfInputRefPhase, 0, k_nFftForTensorSize * sizeof(float)); + dtln_freq_output_ = new float[kFftForTensorSize]; + dtln_time_output_ = new float[kWindowSize]; - this->m_lpfInputRecMag = new float[k_nFftForTensorSize]; - this->m_lpfInputRecPhase = new float[k_nFftForTensorSize]; + memset(dtln_freq_output_, 0, kFftForTensorSize * sizeof(float)); + memset(dtln_time_output_, 0, kWindowSize * sizeof(float)); - memset(this->m_lpfInputRecMag, 0, k_nFftForTensorSize * sizeof(float)); - memset(this->m_lpfInputRecPhase, 0, k_nFftForTensorSize * sizeof(float)); + input_ref_mag_ = new float[kFftForTensorSize]; + input_ref_phase_ = new float[kFftForTensorSize]; - this->m_lpfEstimatedBlock = new float[k_nWindowSize]; + memset(input_ref_mag_, 0, kFftForTensorSize * sizeof(float)); + memset(input_ref_phase_, 0, kFftForTensorSize * sizeof(float)); - memset(this->m_lpfEstimatedBlock, 0, k_nWindowSize * sizeof(float)); + input_rec_mag_ = new float[kFftForTensorSize]; + input_rec_phase_ = new float[kFftForTensorSize]; + memset(input_rec_mag_, 0, kFftForTensorSize * sizeof(float)); + memset(input_rec_phase_, 0, kFftForTensorSize * sizeof(float)); - for (int i = 0; i < k_nNumModels; i++) - { - this->m_lppfStates[i] = new float[this->m_lpnStateSize[i]]; - memset(this->m_lppfStates[i], 0, this->m_lpnStateSize[i] * sizeof(float)); - } - - //Format change buffer - this->m_lpfInputRefSample = new float[k_nWindowSize]; - this->m_lpfInputRecSample = new float[k_nWindowSize]; - this->m_lpfOutputSample = new float[k_nWindowSize]; + estimated_block_ = new float[kWindowSize]; - memset(this->m_lpfInputRefSample, 0, k_nWindowSize * sizeof(float)); - memset(this->m_lpfInputRecSample, 0, k_nWindowSize * sizeof(float)); - memset(this->m_lpfOutputSample, 0, k_nWindowSize * sizeof(float)); + memset(estimated_block_, 0, kWindowSize * sizeof(float)); - this->m_bInitSuccess = true; + for (int i = 0; i < kNumModels; i++) { + states_[i] = new float[state_size_[i]]; + memset(states_[i], 0, state_size_[i] * sizeof(float)); + } - nRet = k_nWindowSize; + // Format change buffer + input_ref_sample_ = new float[kWindowSize]; + input_rec_sample_ = new float[kWindowSize]; + output_sample_ = new float[kWindowSize]; - } - while (0); + memset(input_ref_sample_, 0, kWindowSize * sizeof(float)); + memset(input_rec_sample_, 0, kWindowSize * sizeof(float)); + memset(output_sample_, 0, kWindowSize * sizeof(float)); - return nRet; -} + init_success_ = true; + + ret = kWindowSize; + + } while (0); -void DTLN_AEC::m_Impl::Release(void) -{ - //Tensorflow lite - for (int i = 0; i < k_nNumModels; i++) - { - if (this->m_lppoTfliteModel[i] != nullptr) - TfLiteModelDelete(this->m_lppoTfliteModel[i]); + return ret; +} - if (this->m_lppoInterpreter[i] != nullptr) - TfLiteInterpreterDelete(this->m_lppoInterpreter[i]); - } +void DTLN_AEC::Impl::Release() { + // Tensorflow lite + for (int i = 0; i < kNumModels; i++) { + if (tflite_models_[i] != nullptr) TfLiteModelDelete(tflite_models_[i]); - if (this->m_lpoInterpreterOptions != nullptr) - TfLiteInterpreterOptionsDelete(this->m_lpoInterpreterOptions); + if (interpreters_[i] != nullptr) TfLiteInterpreterDelete(interpreters_[i]); + } - //RFFT/iRFFT - if (this->m_lpoFftrCfg != nullptr) - kiss_fft_free(this->m_lpoFftrCfg); + if (interpreter_options_ != nullptr) + TfLiteInterpreterOptionsDelete(interpreter_options_); - if (this->m_lpoIfftrCfg != nullptr) - kiss_fft_free(this->m_lpoIfftrCfg); + // RFFT/iRFFT + if (fftr_cfg_ != nullptr) kiss_fft_free(fftr_cfg_); - if (this->m_lpoInputRefCpx != nullptr) - delete[] this->m_lpoInputRefCpx; + if (ifftr_cfg_ != nullptr) kiss_fft_free(ifftr_cfg_); - if (this->m_lpoInputRecCpx != nullptr) - delete[] this->m_lpoInputRecCpx; + if (input_ref_cpx_ != nullptr) delete[] input_ref_cpx_; - if (this->m_lpoOutputCpx != nullptr) - delete[] this->m_lpoOutputCpx; + if (input_rec_cpx_ != nullptr) delete[] input_rec_cpx_; + if (output_cpx_ != nullptr) delete[] output_cpx_; - //Internal buffer - if (this->m_lpfInputRefBuffer != nullptr) - delete[] this->m_lpfInputRefBuffer; + // Internal buffer + if (input_ref_buffer_ != nullptr) delete[] input_ref_buffer_; - if (this->m_lpfInputRecBuffer != nullptr) - delete[] this->m_lpfInputRecBuffer; + if (input_rec_buffer_ != nullptr) delete[] input_rec_buffer_; - if (this->m_lpfOutputBuffer != nullptr) - delete[] this->m_lpfOutputBuffer; + if (output_buffer_ != nullptr) delete[] output_buffer_; - if (this->m_lpfDtlnFreqOutput != nullptr) - delete[] this->m_lpfDtlnFreqOutput; + if (dtln_freq_output_ != nullptr) delete[] dtln_freq_output_; - if (this->m_lpfDtlnTimeOutput != nullptr) - delete[] this->m_lpfDtlnTimeOutput; + if (dtln_time_output_ != nullptr) delete[] dtln_time_output_; - for (int i = 0; i < k_nNumModels; i++) - { - if (this->m_lppfStates[i] != nullptr) - delete[] this->m_lppfStates[i]; - } + for (int i = 0; i < kNumModels; i++) { + if (states_[i] != nullptr) delete[] states_[i]; + } - if (this->m_lpfInputRefMag != nullptr) - delete[] this->m_lpfInputRefMag; + if (input_ref_mag_ != nullptr) delete[] input_ref_mag_; - if (this->m_lpfInputRefPhase != nullptr) - delete[] this->m_lpfInputRefPhase; + if (input_ref_phase_ != nullptr) delete[] input_ref_phase_; - if (this->m_lpfInputRecMag != nullptr) - delete[] this->m_lpfInputRecMag; + if (input_rec_mag_ != nullptr) delete[] input_rec_mag_; - if (this->m_lpfInputRecPhase != nullptr) - delete[] this->m_lpfInputRecPhase; + if (input_rec_phase_ != nullptr) delete[] input_rec_phase_; - if (this->m_lpfEstimatedBlock != nullptr) - delete[] this->m_lpfEstimatedBlock; + if (estimated_block_ != nullptr) delete[] estimated_block_; - //Format change buffer - if (this->m_lpfInputRefSample != nullptr) - delete[] this->m_lpfInputRefSample; + // Format change buffer + if (input_ref_sample_ != nullptr) delete[] input_ref_sample_; - if (this->m_lpfInputRecSample != nullptr) - delete[] this->m_lpfInputRecSample; + if (input_rec_sample_ != nullptr) delete[] input_rec_sample_; - if (this->m_lpfOutputSample != nullptr) - delete[] this->m_lpfOutputSample; + if (output_sample_ != nullptr) delete[] output_sample_; } -int DTLN_AEC::m_Impl::Process(short *lpsRefBuffer, short *lpsRecBuffer, short *lpsOutputBuffer) -{ - int nRet = -1; +int DTLN_AEC::Impl::Process(short *ref_buffer, short *rec_buffer, + short *output_buffer) { + int ret = -1; - do - { - if (this->m_bInitSuccess == false) - break; + do { + if (init_success_ == false) break; - if (lpsRefBuffer == nullptr || lpsRecBuffer == nullptr || lpsOutputBuffer == nullptr) - break; + if (ref_buffer == nullptr || rec_buffer == nullptr || + output_buffer == nullptr) + break; - //Convert short to float - for (int i = 0; i < k_nWindowSize; i++) - { - this->m_lpfInputRefSample[i] = (float)lpsRefBuffer[i] * 1.0f / SHRT_MAX; - } + // Convert short to float + for (int i = 0; i < kWindowSize; i++) { + input_ref_sample_[i] = (float)ref_buffer[i] * 1.0f / SHRT_MAX; + } - for (int i = 0; i < k_nWindowSize; i++) - { - this->m_lpfInputRecSample[i] = (float)lpsRecBuffer[i] * 1.0f / SHRT_MAX; - } + for (int i = 0; i < kWindowSize; i++) { + input_rec_sample_[i] = (float)rec_buffer[i] * 1.0f / SHRT_MAX; + } - this->AEC(); + AEC(); - //Convert float to short - for (int i = 0; i < k_nWindowSize; i++) - { - lpsOutputBuffer[i] = (short)(this->m_lpfOutputSample[i] * SHRT_MAX); - } + // Convert float to short + for (int i = 0; i < kWindowSize; i++) { + output_buffer[i] = (short)(output_sample_[i] * SHRT_MAX); + } - nRet = 0; - } - while (0); + ret = 0; + } while (0); - return nRet; + return ret; } -void DTLN_AEC::m_Impl::AEC(void) -{ - int nNumBlocks = k_nWindowSize / k_nWindowShift; - - float *pfInputRefSample = this->m_lpfInputRefSample; - float *pfInputRecSample = this->m_lpfInputRecSample; - float *pfOutputSample = this->m_lpfOutputSample; - - for (int i = 0; i < nNumBlocks; i++) - { - //Buffer shift to match FFT size - memmove(this->m_lpfInputRefBuffer, this->m_lpfInputRefBuffer + k_nWindowShift, (k_nWindowSize - k_nWindowShift) * sizeof(float)); - memcpy(this->m_lpfInputRefBuffer + (k_nWindowSize - k_nWindowShift), pfInputRefSample, k_nWindowShift * sizeof(float)); - - memmove(this->m_lpfInputRecBuffer, this->m_lpfInputRecBuffer + k_nWindowShift, (k_nWindowSize - k_nWindowShift) * sizeof(float)); - memcpy(this->m_lpfInputRecBuffer + (k_nWindowSize - k_nWindowShift), pfInputRecSample, k_nWindowShift * sizeof(float)); - - //Prepare buffer - memset(this->m_lpfInputRefMag, 0, k_nFftForTensorSize * sizeof(float)); - memset(this->m_lpfInputRefPhase, 0, k_nFftForTensorSize * sizeof(float)); - - memset(this->m_lpfInputRecMag, 0, k_nFftForTensorSize * sizeof(float)); - memset(this->m_lpfInputRecPhase, 0, k_nFftForTensorSize * sizeof(float)); - - memset(this->m_lpfEstimatedBlock, 0, k_nWindowSize * sizeof(float)); - - //Use RFFT/iRFFT to implement STFT/iSTFT - - //RFFT - kiss_fftr(this->m_lpoFftrCfg, this->m_lpfInputRefBuffer, this->m_lpoInputRefCpx); - kiss_fftr(this->m_lpoFftrCfg, this->m_lpfInputRecBuffer, this->m_lpoInputRecCpx); - - //Calculate Mag/Phase - for (int j = 0; j < k_nFftForTensorSize; j++) - { - //How to calculate Mag/Phase: - //check 3a/3b in https://www.gaussianwaves.com/2015/11/interpreting-fft-results-obtaining-magnitude-and-phase-information/ - this->m_lpfInputRefMag[j] = sqrtf(this->m_lpoInputRefCpx[j].r * this->m_lpoInputRefCpx[j].r + this->m_lpoInputRefCpx[j].i * this->m_lpoInputRefCpx[j].i); - this->m_lpfInputRefPhase[j] = atan2f(this->m_lpoInputRefCpx[j].i, this->m_lpoInputRefCpx[j].r); - - this->m_lpfInputRecMag[j] = sqrtf(this->m_lpoInputRecCpx[j].r * this->m_lpoInputRecCpx[j].r + this->m_lpoInputRecCpx[j].i * this->m_lpoInputRecCpx[j].i); - this->m_lpfInputRecPhase[j] = atan2f(this->m_lpoInputRecCpx[j].i, this->m_lpoInputRecCpx[j].r); - } - - //Set data into tensor - TfLiteTensorCopyFromBuffer(this->m_lppoInputTensor[0][0], this->m_lpfInputRecMag, k_nFftForTensorSize * sizeof(float)); - TfLiteTensorCopyFromBuffer(this->m_lppoInputTensor[0][1], this->m_lppfStates[0], this->m_lpnStateSize[0] * sizeof(float)); - TfLiteTensorCopyFromBuffer(this->m_lppoInputTensor[0][2], this->m_lpfInputRefMag, k_nFftForTensorSize * sizeof(float)); - - //DTLN for freq domain - TfLiteInterpreterInvoke(this->m_lppoInterpreter[0]); - - //Get data from tensor - TfLiteTensorCopyToBuffer(this->m_lppoOutputTensor[0][0], this->m_lpfDtlnFreqOutput, k_nFftForTensorSize * sizeof(float)); - TfLiteTensorCopyToBuffer(this->m_lppoOutputTensor[0][1], this->m_lppfStates[0], this->m_lpnStateSize[0] * sizeof(float)); - - //iRFFT - //this->m_lpfDtlnFreqOutput is out_mask - //Use orignal Mag/Phase to restore generated freq - for (int j = 0; j < k_nFftForTensorSize; j++) - { - //Re{ z } = Re{ a + ib } = Mag * cos[φ] * freq - //Im{ z } = Im{ a + ib } = Mag * sin[φ] * freq - this->m_lpoOutputCpx[j].r = this->m_lpfInputRecMag[j] * cosf(this->m_lpfInputRecPhase[j]) * this->m_lpfDtlnFreqOutput[j]; - this->m_lpoOutputCpx[j].i = this->m_lpfInputRecMag[j] * sinf(this->m_lpfInputRecPhase[j]) * this->m_lpfDtlnFreqOutput[j]; - } - - kiss_fftri(this->m_lpoIfftrCfg, this->m_lpoOutputCpx, this->m_lpfEstimatedBlock); - - //FFT coefficient 1/N - for (int j = 0; j < k_nWindowSize; j++) - this->m_lpfEstimatedBlock[j] = this->m_lpfEstimatedBlock[j] / k_nWindowSize; - - //Set data into tensor - TfLiteTensorCopyFromBuffer(this->m_lppoInputTensor[1][0], this->m_lpfEstimatedBlock, k_nWindowSize * sizeof(float)); - TfLiteTensorCopyFromBuffer(this->m_lppoInputTensor[1][1], this->m_lppfStates[1], this->m_lpnStateSize[1] * sizeof(float)); - TfLiteTensorCopyFromBuffer(this->m_lppoInputTensor[1][2], this->m_lpfInputRefBuffer, k_nWindowSize * sizeof(float)); - - //DTLN for time domain - TfLiteInterpreterInvoke(this->m_lppoInterpreter[1]); - - //Get data from tensor - TfLiteTensorCopyToBuffer(this->m_lppoOutputTensor[1][0], this->m_lpfDtlnTimeOutput, k_nWindowSize * sizeof(float)); - TfLiteTensorCopyToBuffer(this->m_lppoOutputTensor[1][1], this->m_lppfStates[1], this->m_lpnStateSize[1] * sizeof(float)); - - //Overlap add - memmove(this->m_lpfOutputBuffer, this->m_lpfOutputBuffer + k_nWindowShift, (k_nWindowSize - k_nWindowShift) * sizeof(float)); - memset(this->m_lpfOutputBuffer + (k_nWindowSize - k_nWindowShift), 0, k_nWindowShift * sizeof(float)); - - for (int j = 0; j < k_nWindowSize; j++) - this->m_lpfOutputBuffer[j] += this->m_lpfDtlnTimeOutput[j]; - - - memcpy(pfOutputSample, this->m_lpfOutputBuffer, k_nWindowShift * sizeof(float)); - - pfInputRefSample += k_nWindowShift; - pfInputRecSample += k_nWindowShift; - pfOutputSample += k_nWindowShift; - } +void DTLN_AEC::Impl::AEC() { + int num_blocks = kWindowSize / kWindowShift; + + float *input_ref_sample = input_ref_sample_; + float *input_rec_sample = input_rec_sample_; + float *output_sample = output_sample_; + + for (int i = 0; i < num_blocks; i++) { + // Buffer shift to match FFT size + memmove(input_ref_buffer_, input_ref_buffer_ + kWindowShift, + (kWindowSize - kWindowShift) * sizeof(float)); + memcpy(input_ref_buffer_ + (kWindowSize - kWindowShift), input_ref_sample, + kWindowShift * sizeof(float)); + + memmove(input_rec_buffer_, input_rec_buffer_ + kWindowShift, + (kWindowSize - kWindowShift) * sizeof(float)); + memcpy(input_rec_buffer_ + (kWindowSize - kWindowShift), input_rec_sample, + kWindowShift * sizeof(float)); + + // Prepare buffer + memset(input_ref_mag_, 0, kFftForTensorSize * sizeof(float)); + memset(input_ref_phase_, 0, kFftForTensorSize * sizeof(float)); + + memset(input_rec_mag_, 0, kFftForTensorSize * sizeof(float)); + memset(input_rec_phase_, 0, kFftForTensorSize * sizeof(float)); + + memset(estimated_block_, 0, kWindowSize * sizeof(float)); + + // Use RFFT/iRFFT to implement STFT/iSTFT + + // RFFT + kiss_fftr(fftr_cfg_, input_ref_buffer_, input_ref_cpx_); + kiss_fftr(fftr_cfg_, input_rec_buffer_, input_rec_cpx_); + + // Calculate Mag/Phase + for (int j = 0; j < kFftForTensorSize; j++) { + // How to calculate Mag/Phase: + // check 3a/3b in + // https://www.gaussianwaves.com/2015/11/interpreting-fft-results-obtaining-magnitude-and-phase-information/ + input_ref_mag_[j] = sqrtf(input_ref_cpx_[j].r * input_ref_cpx_[j].r + + input_ref_cpx_[j].i * input_ref_cpx_[j].i); + input_ref_phase_[j] = atan2f(input_ref_cpx_[j].i, input_ref_cpx_[j].r); + + input_rec_mag_[j] = sqrtf(input_rec_cpx_[j].r * input_rec_cpx_[j].r + + input_rec_cpx_[j].i * input_rec_cpx_[j].i); + input_rec_phase_[j] = atan2f(input_rec_cpx_[j].i, input_rec_cpx_[j].r); + } + + // Set data into tensor + TfLiteTensorCopyFromBuffer(input_tensors_[0][0], input_rec_mag_, + kFftForTensorSize * sizeof(float)); + TfLiteTensorCopyFromBuffer(input_tensors_[0][1], states_[0], + state_size_[0] * sizeof(float)); + TfLiteTensorCopyFromBuffer(input_tensors_[0][2], input_ref_mag_, + kFftForTensorSize * sizeof(float)); + + // DTLN for freq domain + TfLiteInterpreterInvoke(interpreters_[0]); + + // Get data from tensor + TfLiteTensorCopyToBuffer(output_tensors_[0][0], dtln_freq_output_, + kFftForTensorSize * sizeof(float)); + TfLiteTensorCopyToBuffer(output_tensors_[0][1], states_[0], + state_size_[0] * sizeof(float)); + + // iRFFT + // dtln_freq_output_ is out_mask + // Use orignal Mag/Phase to restore generated freq + for (int j = 0; j < kFftForTensorSize; j++) { + // Re{ z } = Re{ a + ib } = Mag * cos[φ] * freq + // Im{ z } = Im{ a + ib } = Mag * sin[φ] * freq + output_cpx_[j].r = + input_rec_mag_[j] * cosf(input_rec_phase_[j]) * dtln_freq_output_[j]; + output_cpx_[j].i = + input_rec_mag_[j] * sinf(input_rec_phase_[j]) * dtln_freq_output_[j]; + } + + kiss_fftri(ifftr_cfg_, output_cpx_, estimated_block_); + + // FFT coefficient 1/N + for (int j = 0; j < kWindowSize; j++) + estimated_block_[j] = estimated_block_[j] / kWindowSize; + + // Set data into tensor + TfLiteTensorCopyFromBuffer(input_tensors_[1][0], estimated_block_, + kWindowSize * sizeof(float)); + TfLiteTensorCopyFromBuffer(input_tensors_[1][1], states_[1], + state_size_[1] * sizeof(float)); + TfLiteTensorCopyFromBuffer(input_tensors_[1][2], input_ref_buffer_, + kWindowSize * sizeof(float)); + + // DTLN for time domain + TfLiteInterpreterInvoke(interpreters_[1]); + + // Get data from tensor + TfLiteTensorCopyToBuffer(output_tensors_[1][0], dtln_time_output_, + kWindowSize * sizeof(float)); + TfLiteTensorCopyToBuffer(output_tensors_[1][1], states_[1], + state_size_[1] * sizeof(float)); + + // Overlap add + memmove(output_buffer_, output_buffer_ + kWindowShift, + (kWindowSize - kWindowShift) * sizeof(float)); + memset(output_buffer_ + (kWindowSize - kWindowShift), 0, + kWindowShift * sizeof(float)); + + for (int j = 0; j < kWindowSize; j++) + output_buffer_[j] += dtln_time_output_[j]; + + memcpy(output_sample, output_buffer_, kWindowShift * sizeof(float)); + + input_ref_sample += kWindowShift; + input_rec_sample += kWindowShift; + output_sample += kWindowShift; + } } -DTLN_AEC::DTLN_AEC() :m_lpoImpl(new DTLN_AEC::m_Impl) -{ -} +DTLN_AEC::DTLN_AEC() : impl_(new DTLN_AEC::Impl) {} -DTLN_AEC::~DTLN_AEC() -{ - this->m_lpoImpl->Release(); +DTLN_AEC::~DTLN_AEC() { + impl_->Release(); - delete this->m_lpoImpl; - this->m_lpoImpl = nullptr; + delete impl_; + impl_ = nullptr; } -int DTLN_AEC::Init(void) -{ - return this->m_lpoImpl->Init(); -} +int DTLN_AEC::Init() { return impl_->Init(); } -int DTLN_AEC::Process(short *lpsRefBuffer, short *lpsRecBuffer, short *lpsOutputBuffer) -{ - return this->m_lpoImpl->Process(lpsRefBuffer, lpsRecBuffer, lpsOutputBuffer); +int DTLN_AEC::Process(short *ref_buffer, short *rec_buffer, + short *output_buffer) { + return impl_->Process(ref_buffer, rec_buffer, output_buffer); } diff --git a/DTLN_AEC/DTLN_AEC.h b/DTLN_AEC/DTLN_AEC.h index 17d2791..8e12662 100644 --- a/DTLN_AEC/DTLN_AEC.h +++ b/DTLN_AEC/DTLN_AEC.h @@ -1,34 +1,33 @@ +#ifndef DTLN_AEC_DTLN_AEC_H_ +#define DTLN_AEC_DTLN_AEC_H_ #if defined(_WIN32) || defined(_WIN64) #ifdef DTLNAEC_EXPORTS -# define DTLNAEC __declspec(dllexport) +#define DTLNAEC __declspec(dllexport) #else -# define DTLNAEC __declspec(dllimport) +#define DTLNAEC __declspec(dllimport) #endif -//Only support 16K 16Bit Mono PCM - +// Only support 16K 16Bit Mono PCM. class DTLNAEC DTLN_AEC -//Windows win32/x86_64 #else -class DTLN_AEC //#elif defined(__APPLE__) -//macOS +class DTLN_AEC #endif { -public: - DTLN_AEC(); - ~DTLN_AEC(); + public: + DTLN_AEC(); + ~DTLN_AEC(); - //Return number of input samples, -1 = Fail - int Init(void); + // Returns number of input samples, -1 = fail. + int Init(); - //0 = Success, -1 = Fail - int Process(short *lpsRefBuffer, short *lpsRecBuffer, short *lpsOutputBuffer); + // 0 = success, -1 = fail. + int Process(short* ref_buffer, short* rec_buffer, short* output_buffer); -private: - class m_Impl; - m_Impl *m_lpoImpl = nullptr; + private: + class Impl; + Impl* impl_ = nullptr; }; - +#endif // DTLN_AEC_DTLN_AEC_H_