dsp.cpp
@@ -1,6 +1,4 @@
#include <algorithm>
#include <cstddef>
#include <limits>
#include <math.h>
#include <rnexecutorch/data_processing/FFT.h>
#include <rnexecutorch/data_processing/dsp.h>
@@ -18,48 +16,4 @@ std::vector<float> hannWindow(size_t size) {
return window;
}

std::vector<float> stftFromWaveform(std::span<const float> waveform,
size_t fftWindowSize, size_t hopSize) {
// Initialize FFT
FFT fft(fftWindowSize);

const auto numFrames = 1 + (waveform.size() - fftWindowSize) / hopSize;
const auto numBins = fftWindowSize / 2;
const auto hann = hannWindow(fftWindowSize);
auto inBuffer = std::vector<float>(fftWindowSize);
auto outBuffer = std::vector<std::complex<float>>(fftWindowSize);

// Output magnitudes in dB
std::vector<float> magnitudes;
magnitudes.reserve(numFrames * numBins);
const auto magnitudeScale = 1.0f / static_cast<float>(fftWindowSize);
constexpr auto epsilon = std::numeric_limits<float>::epsilon();
constexpr auto dbConversionFactor = 20.0f;

for (size_t t = 0; t < numFrames; ++t) {
const size_t offset = t * hopSize;
// Clear the input buffer first
std::ranges::fill(inBuffer, 0.0f);

// Fill frame with windowed signal
const size_t samplesToRead =
std::min(fftWindowSize, waveform.size() - offset);
for (size_t i = 0; i < samplesToRead; i++) {
inBuffer[i] = waveform[offset + i] * hann[i];
}

fft.doFFT(inBuffer.data(), outBuffer);

// Calculate magnitudes in dB (only positive frequencies)
for (size_t i = 0; i < numBins; i++) {
const auto magnitude = std::abs(outBuffer[i]) * magnitudeScale;
const auto magnitude_db =
dbConversionFactor * log10f(magnitude + epsilon);
magnitudes.push_back(magnitude_db);
}
}

return magnitudes;
}

} // namespace rnexecutorch::dsp
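
With this change, dsp.cpp keeps only hannWindow: the STFT preprocessing is removed because the encoder now takes the raw waveform directly (see the ASR::encode hunk below). For reference, the deleted helper framed the signal (1 + (N - fftWindowSize) / hopSize frames spaced hopSize apart), applied a Hann window, ran an FFT, and converted the first fftWindowSize / 2 bins of each frame to decibels. A minimal reconstruction of that per-bin conversion, distilled from the deleted lines above rather than taken from the PR:

#include <cmath>
#include <complex>
#include <limits>

// Distilled from the removed stftFromWaveform:
// magnitude_db = 20 * log10(|X[k]| / fftWindowSize + epsilon)
inline float binToDb(std::complex<float> bin, float fftWindowSize) {
  const float magnitude = std::abs(bin) / fftWindowSize;
  return 20.0f * std::log10(magnitude + std::numeric_limits<float>::epsilon());
}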
BaseModel.cpp
@@ -30,7 +30,7 @@ BaseModel::BaseModel(const std::string &modelSource,
}

std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
int32_t index) {
int32_t index) const {
if (!module_) {
throw std::runtime_error("Model not loaded: Cannot get input shape");
}
@@ -56,7 +56,7 @@ std::vector<int32_t> BaseModel::getInputShape(std::string method_name,
}

std::vector<std::vector<int32_t>>
BaseModel::getAllInputShapes(std::string methodName) {
BaseModel::getAllInputShapes(std::string methodName) const {
if (!module_) {
throw std::runtime_error("Model not loaded: Cannot get all input shapes");
}
@@ -88,7 +88,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
/// to JS. It is not meant to be used within C++. If you want to call forward
/// from C++ on a BaseModel, please use BaseModel::forward.
std::vector<JSTensorViewOut>
BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const {
if (!module_) {
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
}
@@ -136,7 +136,7 @@ BaseModel::forwardJS(std::vector<JSTensorViewIn> tensorViewVec) {
}

Result<executorch::runtime::MethodMeta>
BaseModel::getMethodMeta(const std::string &methodName) {
BaseModel::getMethodMeta(const std::string &methodName) const {
if (!module_) {
throw std::runtime_error("Model not loaded: Cannot get method meta!");
}
@@ -161,7 +161,7 @@ BaseModel::forward(const std::vector<EValue> &input_evalues) const {

Result<std::vector<EValue>>
BaseModel::execute(const std::string &methodName,
const std::vector<EValue> &input_value) {
const std::vector<EValue> &input_value) const {
if (!module_) {
throw std::runtime_error("Model not loaded, cannot run execute.");
}
@@ -175,7 +175,7 @@ std::size_t BaseModel::getMemoryLowerBound() const noexcept {
void BaseModel::unload() noexcept { module_.reset(nullptr); }

std::vector<int32_t>
BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) {
BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const {
auto sizes = tensor.sizes();
return std::vector<int32_t>(sizes.begin(), sizes.end());
}
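
The const sweep above is what lets callers hold these models through pointers-to-const; the ASR class below stores its encoder and decoder exactly that way (const models::BaseModel *). A minimal sketch of a read-only caller that only compiles with the const-qualified accessors; inspectModel and the include path are illustrative assumptions, not part of this PR:

#include <cstdint>
#include <rnexecutorch/models/BaseModel.h> // assumed header location

// Hypothetical read-only caller: with getInputShape now const-qualified,
// it can accept a reference-to-const.
void inspectModel(const rnexecutorch::models::BaseModel &model) {
  const auto shape = model.getInputShape("forward", 0);
  for (const int32_t dim : shape) {
    // e.g. validate or log each dimension
    (void)dim;
  }
}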
BaseModel.h
@@ -25,18 +25,20 @@ class BaseModel {
Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors);
std::size_t getMemoryLowerBound() const noexcept;
void unload() noexcept;
std::vector<int32_t> getInputShape(std::string method_name, int32_t index);
std::vector<int32_t> getInputShape(std::string method_name,
int32_t index) const;
std::vector<std::vector<int32_t>>
getAllInputShapes(std::string methodName = "forward");
getAllInputShapes(std::string methodName = "forward") const;
std::vector<JSTensorViewOut>
forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
forwardJS(std::vector<JSTensorViewIn> tensorViewVec) const;
Result<std::vector<EValue>> forward(const EValue &input_value) const;
Result<std::vector<EValue>>
forward(const std::vector<EValue> &input_value) const;
Result<std::vector<EValue>> execute(const std::string &methodName,
const std::vector<EValue> &input_value);
Result<std::vector<EValue>>
execute(const std::string &methodName,
const std::vector<EValue> &input_value) const;
Result<executorch::runtime::MethodMeta>
getMethodMeta(const std::string &methodName);
getMethodMeta(const std::string &methodName) const;

protected:
// If possible, models should not use the JS runtime to keep JSI internals
@@ -49,7 +51,8 @@ class BaseModel {
std::size_t memorySizeLowerBound{0};

private:
std::vector<int32_t> getTensorShape(const executorch::aten::Tensor &tensor);
std::vector<int32_t>
getTensorShape(const executorch::aten::Tensor &tensor) const;
};
} // namespace models

ASR.cpp
@@ -4,7 +4,6 @@
#include "ASR.h"
#include "executorch/extension/tensor/tensor_ptr.h"
#include "rnexecutorch/data_processing/Numerical.h"
#include "rnexecutorch/data_processing/dsp.h"
#include "rnexecutorch/data_processing/gzip.h"

namespace rnexecutorch::models::speech_to_text::asr {
@@ -37,8 +36,7 @@ ASR::getInitialSequence(const DecodingOptions &options) const {
return seq;
}

GenerationResult ASR::generate(std::span<const float> waveform,
float temperature,
GenerationResult ASR::generate(std::span<float> waveform, float temperature,
const DecodingOptions &options) const {
std::vector<float> encoderOutput = this->encode(waveform);

@@ -94,7 +92,7 @@ float ASR::getCompressionRatio(const std::string &text) const {
}

std::vector<Segment>
ASR::generateWithFallback(std::span<const float> waveform,
ASR::generateWithFallback(std::span<float> waveform,
const DecodingOptions &options) const {
std::vector<float> temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
std::vector<int32_t> bestTokens;
@@ -209,7 +207,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span<const int32_t> tokens,
return wordObjs;
}

std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
std::vector<Segment> ASR::transcribe(std::span<float> waveform,
const DecodingOptions &options) const {
int32_t seek = 0;
std::vector<Segment> results;
@@ -218,7 +216,7 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
int32_t start = seek * ASR::kSamplingRate;
const auto end = std::min<int32_t>(
(seek + ASR::kChunkSize) * ASR::kSamplingRate, waveform.size());
std::span<const float> chunk = waveform.subspan(start, end - start);
auto chunk = waveform.subspan(start, end - start);

if (std::cmp_less(chunk.size(), ASR::kMinChunkSamples)) {
break;
@@ -246,19 +244,12 @@ std::vector<Segment> ASR::transcribe(std::span<const float> waveform,
return results;
}

std::vector<float> ASR::encode(std::span<const float> waveform) const {
constexpr int32_t fftWindowSize = 512;
constexpr int32_t stftHopLength = 160;
constexpr int32_t innerDim = 256;

std::vector<float> preprocessedData =
dsp::stftFromWaveform(waveform, fftWindowSize, stftHopLength);
const auto numFrames =
static_cast<int32_t>(preprocessedData.size()) / innerDim;
std::vector<int32_t> inputShape = {numFrames, innerDim};
std::vector<float> ASR::encode(std::span<float> waveform) const {
auto inputShape = {static_cast<int32_t>(waveform.size())};

const auto modelInputTensor = executorch::extension::make_tensor_ptr(
std::move(inputShape), std::move(preprocessedData));
std::move(inputShape), waveform.data(),
executorch::runtime::etensor::ScalarType::Float);
const auto encoderResult = this->encoder->forward(modelInputTensor);

if (!encoderResult.ok()) {
@@ -268,7 +259,7 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
}

const auto decoderOutputTensor = encoderResult.get().at(0).toTensor();
const int32_t outputNumel = decoderOutputTensor.numel();
const auto outputNumel = decoderOutputTensor.numel();

const float *const dataPtr = decoderOutputTensor.const_data_ptr<float>();
return {dataPtr, dataPtr + outputNumel};
@@ -277,8 +268,10 @@ std::vector<float> ASR::encode(std::span<const float> waveform) const {
std::vector<float> ASR::decode(std::span<int32_t> tokens,
std::span<float> encoderOutput) const {
std::vector<int32_t> tokenShape = {1, static_cast<int32_t>(tokens.size())};
auto tokensLong = std::vector<int64_t>(tokens.begin(), tokens.end());

auto tokenTensor = executorch::extension::make_tensor_ptr(
std::move(tokenShape), tokens.data(), ScalarType::Int);
tokenShape, tokensLong.data(), ScalarType::Long);

const auto encoderOutputSize = static_cast<int32_t>(encoderOutput.size());
std::vector<int32_t> encShape = {1, ASR::kNumFrames,
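
Two coupled changes land in this file. First, encode now wraps the caller's waveform buffer in a 1-D float tensor without copying; the pointer-taking make_tensor_ptr overload wants mutable data, which is why every std::span<const float> here becomes std::span<float>. Second, decode widens the int32 tokens to int64 so the token tensor can be created with ScalarType::Long. A minimal sketch of the new non-owning encode path, mirroring the hunk above; encodeSketch and the BaseModel include path are assumptions for illustration, and error handling is trimmed:

#include "executorch/extension/tensor/tensor_ptr.h"
#include <rnexecutorch/models/BaseModel.h> // assumed header location
#include <span>
#include <stdexcept>
#include <vector>

std::vector<float> encodeSketch(const rnexecutorch::models::BaseModel &encoder,
                                std::span<float> waveform) {
  auto inputShape = {static_cast<int32_t>(waveform.size())};
  // Non-owning view: waveform's buffer must stay alive (and mutable)
  // for the duration of the forward call.
  const auto input = executorch::extension::make_tensor_ptr(
      std::move(inputShape), waveform.data(),
      executorch::runtime::etensor::ScalarType::Float);
  const auto result = encoder.forward(input);
  if (!result.ok()) {
    throw std::runtime_error("Forward pass failed in the encoder.");
  }
  const auto outTensor = result.get().at(0).toTensor();
  const float *const dataPtr = outTensor.const_data_ptr<float>();
  return {dataPtr, dataPtr + outTensor.numel()};
}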
ASR.h
@@ -14,9 +14,9 @@ class ASR {
const models::BaseModel *decoder,
const TokenizerModule *tokenizer);
std::vector<types::Segment>
transcribe(std::span<const float> waveform,
transcribe(std::span<float> waveform,
const types::DecodingOptions &options) const;
std::vector<float> encode(std::span<const float> waveform) const;
std::vector<float> encode(std::span<float> waveform) const;
std::vector<float> decode(std::span<int32_t> tokens,
std::span<float> encoderOutput) const;

@@ -44,11 +44,10 @@ class ASR {

std::vector<int32_t>
getInitialSequence(const types::DecodingOptions &options) const;
types::GenerationResult generate(std::span<const float> waveform,
float temperature,
types::GenerationResult generate(std::span<float> waveform, float temperature,
const types::DecodingOptions &options) const;
std::vector<types::Segment>
generateWithFallback(std::span<const float> waveform,
generateWithFallback(std::span<float> waveform,
const types::DecodingOptions &options) const;
std::vector<types::Segment>
calculateWordLevelTimestamps(std::span<const int32_t> tokens,
VoiceActivityDetection.cpp
@@ -6,7 +6,6 @@
#include <array>
#include <functional>
#include <numeric>
#include <ranges>
#include <vector>

namespace rnexecutorch::models::voice_activity_detection {
@@ -158,4 +157,4 @@ VoiceActivityDetection::postprocess(const std::vector<float> &scores,
return speechSegments;
}

} // namespace rnexecutorch::models::voice_activity_detection
} // namespace rnexecutorch::models::voice_activity_detection
46 changes: 28 additions & 18 deletions packages/react-native-executorch/src/constants/modelUrls.ts
@@ -307,29 +307,32 @@ export const STYLE_TRANSFER_UDNIE = {
};

// S2T
const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`;
const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;
const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/tokenizer.json`;
const WHISPER_TINY_EN_ENCODER = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_en_encoder_xnnpack.pte`;
const WHISPER_TINY_EN_DECODER = `${URL_PREFIX}-whisper-tiny.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_en_decoder_xnnpack.pte`;

const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/tokenizer.json`;
const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`;
const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`;
const WHISPER_TINY_EN_ENCODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_encoder_xnnpack.pte`;
const WHISPER_TINY_EN_DECODER_QUANTIZED = `${URL_PREFIX}-whisper-tiny-quantized.en/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_quantized_en_decoder_xnnpack.pte`;

const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/tokenizer.json`;
const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`;
const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`;
const WHISPER_BASE_EN_TOKENIZER = `${URL_PREFIX}-whisper-base.en/${NEXT_VERSION_TAG}/tokenizer.json`;
const WHISPER_BASE_EN_ENCODER = `${URL_PREFIX}-whisper-base.en/${NEXT_VERSION_TAG}/xnnpack/whisper_base_en_encoder_xnnpack.pte`;
const WHISPER_BASE_EN_DECODER = `${URL_PREFIX}-whisper-base.en/${NEXT_VERSION_TAG}/xnnpack/whisper_base_en_decoder_xnnpack.pte`;

const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`;
const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`;
const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`;
const WHISPER_SMALL_EN_TOKENIZER = `${URL_PREFIX}-whisper-small.en/${NEXT_VERSION_TAG}/tokenizer.json`;
const WHISPER_SMALL_EN_ENCODER = `${URL_PREFIX}-whisper-small.en/${NEXT_VERSION_TAG}/xnnpack/whisper_small_en_encoder_xnnpack.pte`;
const WHISPER_SMALL_EN_DECODER = `${URL_PREFIX}-whisper-small.en/${NEXT_VERSION_TAG}/xnnpack/whisper_small_en_decoder_xnnpack.pte`;

const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/tokenizer.json`;
const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`;
const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`;
const WHISPER_TINY_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${NEXT_VERSION_TAG}/tokenizer.json`;
const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_encoder_xnnpack.pte`;
const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${NEXT_VERSION_TAG}/xnnpack/whisper_tiny_decoder_xnnpack.pte`;

const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/tokenizer.json`;
const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`;
const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`;
const WHISPER_BASE_TOKENIZER = `${URL_PREFIX}-whisper-base/${NEXT_VERSION_TAG}/tokenizer.json`;
const WHISPER_BASE_ENCODER_MODEL = `${URL_PREFIX}-whisper-base/${NEXT_VERSION_TAG}/xnnpack/whisper_base_encoder_xnnpack.pte`;
const WHISPER_BASE_DECODER_MODEL = `${URL_PREFIX}-whisper-base/${NEXT_VERSION_TAG}/xnnpack/whisper_base_decoder_xnnpack.pte`;

const WHISPER_SMALL_TOKENIZER = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/tokenizer.json`;
const WHISPER_SMALL_ENCODER_MODEL = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/xnnpack/whisper_small_encoder_xnnpack.pte`;
const WHISPER_SMALL_DECODER_MODEL = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/xnnpack/whisper_small_decoder_xnnpack.pte`;

export const WHISPER_TINY_EN = {
isMultilingual: false,
@@ -338,6 +341,13 @@ export const WHISPER_TINY_EN = {
tokenizerSource: WHISPER_TINY_EN_TOKENIZER,
};

export const WHISPER_TINY_EN_QUANTIZED = {
isMultilingual: false,
encoderSource: WHISPER_TINY_EN_ENCODER_QUANTIZED,
decoderSource: WHISPER_TINY_EN_DECODER_QUANTIZED,
tokenizerSource: WHISPER_TINY_EN_TOKENIZER,
};

export const WHISPER_BASE_EN = {
isMultilingual: false,
encoderSource: WHISPER_BASE_EN_ENCODER,