Skip to content

Commit 2257b5b

Browse files
committed
DRAFT STT
1 parent 8f8f1d8 commit 2257b5b

File tree

3 files changed

+148
-22
lines changed

3 files changed

+148
-22
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#pragma once
2+
3+
#include "AudioTools/CoreAudio/AudioStreams.h"
4+
#include "AudioTools/CoreAudio/Buffers.h"
5+
6+
namespace audio_tools {
7+
/**
8+
* @class EchoCancellation
9+
* @brief Echo cancellation with adaptive LMS filtering for microcontrollers.
10+
*
11+
* This class implements echo cancellation using an adaptive FIR filter (LMS
12+
* algorithm). It estimates the echo path and subtracts the estimated echo from
13+
* the microphone input.
14+
*/
15+
template <typename T = int16_t>
16+
class EchoCancellation : public AudioStream {
17+
public:
18+
/**
19+
* @brief Constructor
20+
* @param in Reference to the input stream (microphone or audio input)
21+
* @param lag_samples Number of samples to delay the echo subtraction
22+
* (default: 0)
23+
* @param buffer_size Size of the internal ring buffer (default: 512)
24+
*/
25+
EchoCancellation(Stream& in, size_t lag_samples = 0, size_t buffer_size = 512,
26+
size_t filter_len = 32, float mu = 0.001f)
27+
: lag(lag_samples),
28+
buffer_size(buffer_size),
29+
filter_len(filter_len),
30+
adaptation_rate(mu) {
31+
p_io = &in;
32+
filter.resize(filter_len, 0.0f);
33+
reset();
34+
}
35+
36+
/**
37+
* @brief Store the output signal (sent to speaker)
38+
* @param buf Pointer to PCM data sent to the speaker (T*)
39+
* @param len Number of bytes in buf
40+
* @return Number of bytes processed
41+
*/
42+
size_t write(const uint8_t* buf, size_t len) override {
43+
// Store output signal in queue for echo estimation
44+
return ring_buffer.writeArray((T*)buf, len / sizeof(T)) *
45+
sizeof(T);
46+
}
47+
48+
/**
49+
* @brief Read input and remove echo (subtract output signal with lag)
50+
* @param buf Pointer to buffer to store processed input (T*)
51+
* @param len Number of bytes to read
52+
* @return Number of bytes read from input
53+
*/
54+
size_t readBytes(uint8_t* buf, size_t len) override {
55+
size_t read = p_io->readBytes(buf, len);
56+
size_t actual_samples = read / sizeof(T);
57+
T* data = (T*)buf;
58+
Vector<T> ref_vec(filter_len, 0);
59+
ring_buffer.peekArray(ref_vec.data(), filter_len);
60+
for (size_t i = 0; i < actual_samples; ++i) {
61+
// Build the reference vector for the adaptive filter
62+
float echo_est = 0.0f;
63+
for (size_t k = 0; k < filter_len; ++k) {
64+
echo_est += filter[k] * ref_vec[k];
65+
}
66+
float mic = (float)data[i];
67+
float error = mic - echo_est;
68+
data[i] = (T)error;
69+
// LMS update
70+
for (size_t k = 0; k < filter_len; ++k) {
71+
filter[k] += adaptation_rate * error * ref_vec[k];
72+
}
73+
T dummy;
74+
ring_buffer.read(dummy); // Advance the queue
75+
// Shift ref_vec left and append dummy
76+
for (size_t k = 0; k < filter_len - 1; ++k) {
77+
ref_vec[k] = ref_vec[k + 1];
78+
}
79+
ref_vec[filter_len - 1] = dummy;
80+
}
81+
return read;
82+
}
83+
84+
/**
85+
* @brief Set the lag (delay) in samples for echo cancellation.
86+
* @param lag_samples Number of samples to delay the echo subtraction
87+
*/
88+
void setLag(size_t lag_samples) { lag = lag_samples; }
89+
90+
/**
91+
* @brief Set the adaptation rate (mu) for the LMS algorithm.
92+
* @param mu Adaptation rate
93+
*/
94+
void setMu(float mu) { adaptation_rate = mu; }
95+
96+
/**
97+
* @brief Set the filter length for the adaptive filter.
98+
* @param len Length of the adaptive filter
99+
*/
100+
void setFilterLen(size_t len) {
101+
filter_len = len;
102+
filter.resize(filter_len, 0.0f);
103+
}
104+
105+
/**
106+
* @brief Reset the internal buffer and lag state.
107+
*/
108+
void reset() {
109+
ring_buffer.resize(buffer_size + lag);
110+
for (size_t j = 0; j < lag; j++) {
111+
ring_buffer.write(0);
112+
}
113+
filter.assign(filter_len, 0.0f);
114+
}
115+
116+
protected:
117+
Stream* p_io = nullptr;
118+
RingBuffer<T> ring_buffer{0};
119+
size_t buffer_size;
120+
size_t lag; // lag in samples
121+
// Adaptive filter
122+
size_t filter_len;
123+
float adaptation_rate = 0.01f;
124+
Vector<float> filter;
125+
};
126+
127+
} // namespace audio_tools

src/AudioTools/STT/EchoCanellation.h

Whitespace-only changes.

src/AudioTools/STT/WakeWordDetector.h

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
namespace audio_tools {
2-
31
#pragma once
42

53
#include <algorithm>
64
#include <cmath>
75

8-
#include "AudioOutput.h"
9-
#include "Vector.h"
6+
#include "AudioTools/CoreAudio/AudioOutput.h"
7+
#include "AudioTools/CoreAudio/AudioBasic/Collections/Vector.h"
8+
#include "AudioTools/CoreAudio/Buffers.h"
9+
#include "AudioTools/AudioLibs/AudioFFT.h"
1010

1111
namespace audio_tools {
1212

@@ -41,12 +41,13 @@ struct FrequencyFrame {
4141
*
4242
* Example:
4343
* @code
44-
* audio_tools::WakeWordDetector<3> detector(fft, fft_size, frame_size);
45-
*detector.addTemplate(my_template_frames, 80.0f, "hello");
46-
*detector.setWakeWordCallback([](const char* name) { Serial.println(name); });
47-
... (file header and includes)
48-
*/
49-
template <size_t N = 3>
44+
* audio_tools::WakeWordDetector<3> detector(fft);
45+
* detector.addTemplate(my_template_frames, 80.0f, "hello");
46+
* detector.setWakeWordCallback([](const char* name) { Serial.println(name); });
47+
* // ...
48+
* @endcode
49+
*/
50+
template <typename T = int16_t, size_t N = 3>
5051
class WakeWordDetector : public AudioOutput {
5152
public:
5253
struct Template {
@@ -61,12 +62,12 @@ class WakeWordDetector : public AudioOutput {
6162

6263
using WakeWordCallback = void (*)(const char* name);
6364

64-
WakeWordDetector(AudioFFTBase& fft, size_t fft_size, size_t frame_size)
65-
: _fft_size(fft_size), _frame_size(frame_size), p_fft(&fft) {
66-
_buffer.resize(_frame_size, 0);
65+
WakeWordDetector(AudioFFTBase& fft)
66+
: p_fft(&fft) {
6767
_frame_pos = 0;
68-
fft.config().ref = this;
69-
fft.callback = fftResult;
68+
auto& fft_cfg = fft.config();
69+
fft_cfg.ref = this;
70+
fft_cfg.callback = fftResult;
7071
}
7172

7273
void startRecording() {
@@ -94,17 +95,17 @@ class WakeWordDetector : public AudioOutput {
9495

9596
void setWakeWordCallback(WakeWordCallback cb) { _callback = cb; }
9697

97-
size_t write(const void* buf, size_t count) override {
98-
return p_fft->write((const uint8_t*)buf, count);
98+
size_t write(const uint8_t* buf, size_t size) override {
99+
return p_fft->write(buf, size);
99100
}
100101

101102
static void fftResult(AudioFFTBase& fft) {
102103
// This static method must access instance data via fft.config().ref
103-
auto* self = static_cast<WakeWordDetector<N>*>(fft.config().ref);
104+
auto* self = static_cast<WakeWordDetector<T,N>*>(fft.config().ref);
104105
if (!self) return;
105106
FrequencyFrame<N> frame;
106107
AudioFFTResult result[N];
107-
self->p_fft->resultArray(result, N);
108+
fft.resultArray(result);
108109
for (size_t j = 0; j < N; j++) {
109110
frame.top_freqs[j] = result[j].frequency;
110111
}
@@ -130,11 +131,9 @@ class WakeWordDetector : public AudioOutput {
130131
protected:
131132
Vector<Template> _templates; ///< List of wake word templates
132133
Vector<FrequencyFrame<N>> _recent_frames; ///< Recent frames for comparison
133-
Vector<int16_t> _buffer; ///< Buffer for incoming PCM samples
134+
Vector<T> _buffer; ///< Buffer for incoming PCM samples
134135
AudioFFTBase* p_fft = nullptr;
135136
bool _is_recording = false; ///< True if currently recording a template
136-
size_t _fft_size; ///< FFT size per frame
137-
size_t _frame_size; ///< Number of PCM samples per frame
138137
size_t _frame_pos; ///< Current position in frame buffer
139138
size_t _max_template_len = 0; ///< Length of the longest template
140139
WakeWordCallback _callback = nullptr;

0 commit comments

Comments
 (0)