diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..69826a8 Binary files /dev/null and b/.DS_Store differ diff --git a/configs/config.yml b/configs/config.yml index 9004160..517fbf5 100644 --- a/configs/config.yml +++ b/configs/config.yml @@ -8,6 +8,8 @@ speech_config: normalize_signal: True normalize_feature: True normalize_per_feature: False + use_fma: True + use_neon: False model_config: name: acrnn @@ -16,20 +18,26 @@ model_config: kernel_size: [[11,5],[11,5],[11,5]] rnn_cell: 256 seq_mask: True + num_languages: 100 dataset_config: vocabulary: vocab/vocab.txt data_path: ./data/wavs/ - corpus_name: ./data/demo_txt/demo + corpus_name: ./data/multilingual/ + fleurs_path: ./data/fleurs/ file_nums: 1 max_audio_length: 2000 shuffle_size: 1200 data_length: None suffix: .txt - load_type: txt + load_type: multilingual train: train - dev: dev + dev: validation test: test + languages_file: configs/languages.json + max_samples_per_language: 10000 + audio_format: wav + metadata_format: json optimizer_config: init_steps: 0 @@ -38,12 +46,15 @@ optimizer_config: beta1: 0.9 beta2: 0.999 epsilon: 1e-9 + use_mixed_precision: True running_config: - prefetch: False - load_weights: ./saved_weights/20230228-084356/last/model + prefetch: True + load_weights: ./saved_weights/multilingual/last/model num_epochs: 100 - batch_size: 1 - train_steps: 50 - dev_steps: 10 - test_steps: 10 \ No newline at end of file + batch_size: 32 + train_steps: 1000 + dev_steps: 100 + test_steps: 100 + save_interval: 5 + eval_interval: 1 \ No newline at end of file diff --git a/configs/languages.json b/configs/languages.json new file mode 100644 index 0000000..5e737e1 --- /dev/null +++ b/configs/languages.json @@ -0,0 +1,18 @@ +{ + "supported_languages": [ + "be_by", + "bg_bg", + "bs_ba", + "ca_cs", + "cs_cz", + "cy_gb" + ], + "language_names": { + "be_by": "Belarusian", + "bg_bg": "Bulgarian", + "bs_ba": "Bosnian", + "ca_cs": "Catalan", + "cs_cz": "Czech", + "cy_gb": "Welsh" + } +} \ No newline at end of file diff --git a/convert_to_pb.py b/convert_to_pb.py index 39f6acf..17d4705 100644 --- a/convert_to_pb.py +++ b/convert_to_pb.py @@ -20,25 +20,38 @@ vocab = Vocab(vocabulary) -# build model -model=Model(**config.model_config,vocab_size=len(vocab.token_list)) +# Build model +model = Model(**config.model_config, vocab_size=len(vocab.token_list)) model.init_build([None, config.speech_config['num_feature_bins']]) model.load_weights(weights_dir + "last/model") model.add_featurizers(speech_featurizer) - version = 2 -#****convert to pb****** -tf.saved_model.save(model, "saved_models/lang14/pb/" + str(version)) -print('convert to pb model successful') -#****convert to serving****** +# Convert to SavedModel format with signatures +@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)]) +def predict_fn(signal): + output, prob = model.predict_pb(signal) + return {"output_0": output, "output_1": prob} + +# Save model with proper signatures tf.saved_model.save( model, - "./saved_models/lang14/serving/"+str(version), + f"saved_models/lang14/pb/{version}", signatures={ - 'predict_pb': model.predict_pb - } + "serving_default": predict_fn, + "predict_pb": model.predict_pb + } ) +print('Model converted to SavedModel format successfully') -print('convert to serving model successful') +# Save model for TensorFlow Serving +tf.saved_model.save( + model, + f"saved_models/lang14/serving/{version}", + signatures={ + "serving_default": predict_fn, + "predict_pb": model.predict_pb + } +) +print('Model converted for 
TensorFlow Serving successfully') diff --git a/download_fleurs.py b/download_fleurs.py new file mode 100644 index 0000000..7ff0280 --- /dev/null +++ b/download_fleurs.py @@ -0,0 +1,184 @@ +import os +import json +import argparse +import shutil +import time +from tqdm import tqdm +from datasets import load_dataset, get_dataset_config_names +import soundfile as sf +import numpy as np +from pathlib import Path + +# All FLEURS languages +ALL_LANGUAGES = [ + 'af', 'am', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', + 'de', 'el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fr', 'ga', 'gl', 'gu', 'ha', 'he', + 'hi', 'hr', 'hu', 'hy', 'id', 'ig', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', + 'ko', 'ky', 'lb', 'lg', 'ln', 'lo', 'lt', 'lv', 'mg', 'mk', 'ml', 'mn', 'mr', 'ms', + 'my', 'ne', 'nl', 'no', 'ny', 'or', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'rw', 'sd', + 'si', 'sk', 'sl', 'sn', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', + 'tk', 'tr', 'uk', 'ur', 'uz', 'vi', 'wo', 'xh', 'yi', 'yo', 'zh', 'zu' +] + +def ensure_dir(path): + """Create directory if it doesn't exist""" + Path(path).mkdir(parents=True, exist_ok=True) + +def save_audio(audio_data, sample_rate, output_path): + """Save audio data to WAV file""" + sf.write(output_path, audio_data, sample_rate) + +def download_language(lang, output_dir, splits=None, retry_count=3, retry_delay=5): + """Download and organize dataset for a specific language with retries""" + if splits is None: + splits = ['train', 'validation', 'test'] + + lang_dir = os.path.join(output_dir, lang) + print(f"\nProcessing language: {lang}") + + for split in splits: + print(f"\nDownloading {split} split...") + split_dir = os.path.join(lang_dir, split) + audio_dir = os.path.join(split_dir, 'audio') + + # Skip if already downloaded + metadata_path = os.path.join(split_dir, 'metadata.json') + if os.path.exists(metadata_path): + print(f"Skipping {lang} {split} - already downloaded") + continue + + ensure_dir(audio_dir) + + # Load dataset with retries + dataset = None + for attempt in range(retry_count): + try: + dataset = load_dataset("google/fleurs", lang, split=split) + break + except Exception as e: + if attempt < retry_count - 1: + print(f"Attempt {attempt + 1} failed for {lang} {split}: {str(e)}") + print(f"Retrying in {retry_delay} seconds...") + time.sleep(retry_delay) + else: + print(f"Error downloading {lang} {split} after {retry_count} attempts: {str(e)}") + return False + + if dataset is None: + continue + + # Prepare metadata + metadata = { + 'data': [], + 'lang': lang, + 'split': split + } + + # Process each example + for idx, item in enumerate(tqdm(dataset, desc=f"Processing {split}")): + try: + # Extract audio + audio_data = item['audio']['array'] + sample_rate = item['audio']['sampling_rate'] + + # Generate ID + item_id = f"{lang}_{split}_{idx:06d}" + + # Save audio file + audio_path = os.path.join(audio_dir, f"{item_id}.wav") + save_audio(audio_data, sample_rate, audio_path) + + # Add to metadata + metadata['data'].append({ + 'id': item_id, + 'transcription': item.get('transcription', ''), + 'raw_transcription': item.get('raw_transcription', ''), + 'language': item.get('language', lang), + 'gender': item.get('gender', ''), + 'lang_id': item.get('lang_id', -1) + }) + + except Exception as e: + print(f"Error processing item {idx} in {lang} {split}: {str(e)}") + continue + + # Save metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, ensure_ascii=False, indent=2) + + print(f"Saved 
{len(metadata['data'])} examples for {lang} {split}")
+
+    return True
+
+def download_languages_in_batches(languages, output_dir, batch_size=5, splits=None):
+    """Download languages in batches to manage memory usage"""
+    total_languages = len(languages)
+    successful = []
+    failed = []
+
+    for i in range(0, total_languages, batch_size):
+        batch = languages[i:i + batch_size]
+        print(f"\nProcessing batch {i//batch_size + 1} of {(total_languages + batch_size - 1)//batch_size}")
+        print(f"Languages in this batch: {', '.join(batch)}")
+
+        for lang in batch:
+            try:
+                if download_language(lang, output_dir, splits):
+                    successful.append(lang)
+                else:
+                    failed.append(lang)
+            except Exception as e:
+                print(f"Failed to download {lang}: {str(e)}")
+                failed.append(lang)
+
+        # Clear some memory
+        if i + batch_size < total_languages:
+            print("\nClearing memory before next batch...")
+            time.sleep(5)  # Give some time for memory cleanup
+
+    return successful, failed
+
+def main():
+    parser = argparse.ArgumentParser(description='Download and organize FLEURS dataset')
+    parser.add_argument('--output_dir', type=str, default='./data/fleurs',
+                        help='Output directory for the dataset')
+    parser.add_argument('--languages', type=str, nargs='+',
+                        help='List of language codes to download (default: all languages)')
+    parser.add_argument('--splits', type=str, nargs='+',
+                        default=['train', 'validation', 'test'],
+                        help='Dataset splits to download')
+    parser.add_argument('--batch_size', type=int, default=5,
+                        help='Number of languages to download in parallel')
+    args = parser.parse_args()
+
+    # Use all languages if none specified
+    languages = args.languages if args.languages else ALL_LANGUAGES
+
+    # Create output directory
+    ensure_dir(args.output_dir)
+
+    # Download languages in batches
+    print(f"Starting download of {len(languages)} languages in batches of {args.batch_size}")
+    successful, failed = download_languages_in_batches(
+        languages, args.output_dir, args.batch_size, args.splits
+    )
+
+    # Print summary
+    print("\n=== Download Summary ===")
+    print(f"Successfully downloaded: {len(successful)} languages")
+    print(f"Failed to download: {len(failed)} languages")
+
+    if failed:
+        print("\nFailed languages:")
+        print(", ".join(failed))
+
+        # Save failed languages to file for retry
+        failed_file = os.path.join(args.output_dir, "failed_languages.txt")
+        with open(failed_file, 'w') as f:
+            f.write("\n".join(failed))
+        print(f"\nFailed languages list saved to: {failed_file}")
+        print("You can retry failed languages using:")
+        print(f"python download_fleurs.py --languages {' '.join(failed)}")
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/featurizers/speech_featurizers.py b/featurizers/speech_featurizers.py
index fb8f83b..aca7324 100644
--- a/featurizers/speech_featurizers.py
+++ b/featurizers/speech_featurizers.py
@@ -232,10 +232,20 @@ def shape(self) -> list:
        return [None, self.num_feature_bins, channel_dim]

    def stft(self, signal):
+        if len(signal) < self.nfft:
+            print(f"[Skip] Signal too short for STFT: len({len(signal)}) < nfft = {self.nfft}")
+            return np.zeros((self.nfft // 2 + 1, 1))
+        max_len = 320000  # Increased from 160000 to match other parts of the code
+        if len(signal) > max_len:
+            print(f"[Truncate] Signal too long for STFT: len({len(signal)}) > max_len = {max_len}")
+            # Take the center portion of the signal
+            start = (len(signal) - max_len) // 2
+            signal = signal[start:start + max_len]
        return np.square(
            np.abs(librosa.core.stft(signal, n_fft=self.nfft,
                             hop_length=self.frame_step, win_length=self.frame_length, center=True, window="hann")))
+

    def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
        return librosa.power_to_db(S, ref=ref, amin=amin, top_db=top_db)

@@ -309,15 +319,46 @@ def compute_mfcc(self, signal: np.ndarray) -> np.ndarray:
        return mfcc.T

    def compute_log_mel_spectrogram(self, signal: np.ndarray) -> np.ndarray:
-        S = self.stft(signal)
-
-        mel = librosa.filters.mel(self.sample_rate, self.nfft,
-                                  n_mels=self.num_feature_bins,
-                                  fmin=0.0, fmax=int(self.sample_rate / 2))
-
-        mel_spectrogram = np.dot(S.T, mel.T)
-
-        return self.power_to_db(mel_spectrogram)
+        """Compute log mel spectrogram with proper error handling for long signals"""
+        try:
+            # Handle long signals - using the same max_len as stft
+            max_len = 320000  # Maximum length for STFT
+            if len(signal) > max_len:
+                print(f"[Truncate] Signal too long: len({len(signal)}) > max_len = {max_len}")
+                # Take the center portion
+                start = (len(signal) - max_len) // 2
+                signal = signal[start:start + max_len]
+
+            # Compute STFT
+            S = self.stft(signal)  # stft will handle any remaining length issues
+
+            # Create mel filterbank if not already created
+            if self.mel_filter is None:
+                self.mel_filter = librosa.filters.mel(
+                    sr=self.sample_rate,
+                    n_fft=self.nfft,
+                    n_mels=self.num_feature_bins,
+                    fmin=0.0,
+                    fmax=int(self.sample_rate / 2)
+                )
+
+            # Apply mel filterbank
+            mel_spectrogram = np.dot(S.T, self.mel_filter.T)
+
+            # Convert to log scale
+            log_mel_spec = self.power_to_db(mel_spectrogram)
+
+            # Handle any NaN or Inf values
+            if np.isnan(log_mel_spec).any() or np.isinf(log_mel_spec).any():
+                print("Warning: NaN or Inf values in log mel spectrogram, replacing with zeros")
+                log_mel_spec = np.nan_to_num(log_mel_spec, nan=0.0)
+
+            return log_mel_spec
+
+        except Exception as e:
+            print(f"Error computing log mel spectrogram: {str(e)}")
+            # Return empty spectrogram with correct shape
+            return np.zeros((1, self.num_feature_bins))

    def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray:
        S = self.stft(signal)
@@ -410,21 +451,6 @@ def tf_extract(self, signal: tf.Tensor) -> tf.Tensor:

        return features

-    def compute_log_mel_spectrogram(self, signal):
-        spectrogram = self.stft(signal)
-        if self.mel_filter is None:
-            linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
-                num_mel_bins=self.num_feature_bins,
-                num_spectrogram_bins=spectrogram.shape[-1],
-                sample_rate=self.sample_rate,
-                lower_edge_hertz=0.0, upper_edge_hertz=(self.sample_rate / 2)
-            )
-        else:
-            linear_to_weight_matrix = self.mel_filter
-
-        mel_spectrogram = tf.tensordot(spectrogram, linear_to_weight_matrix, 1)
-        return self.power_to_db(mel_spectrogram)
-
    def compute_spectrogram(self, signal):
        S = self.stft(signal)
        spectrogram = self.power_to_db(S)
diff --git a/models/.DS_Store b/models/.DS_Store
new file mode 100644
index 0000000..124a457
Binary files /dev/null and b/models/.DS_Store differ
diff --git a/multilingual_dataset.py b/multilingual_dataset.py
new file mode 100644
index 0000000..d9eec3d
--- /dev/null
+++ b/multilingual_dataset.py
@@ -0,0 +1,312 @@
+import os
+import json
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from typing import List, Dict, Optional
+from datasets import load_dataset, Dataset, load_from_disk, concatenate_datasets
+from huggingface_hub import HfFileSystem
+from featurizers.speech_featurizers import NumpySpeechFeaturizer
+from configs.config import Config
+from vocab.vocab import Vocab
+
+class MultilingualDataset:
+    def __init__(
+        self,
+        config:
Config, + languages: List[str], + vocab: Vocab, + speech_featurizer: NumpySpeechFeaturizer, + data_type: str = "train", + max_samples_per_language: Optional[int] = None + ): + self.config = config + self.languages = languages + self.vocab = vocab + self.speech_featurizer = speech_featurizer + self.data_type = data_type + self.max_samples_per_language = max_samples_per_language + self.dataset_cache = {} + + # Initialize language mapping + self.language_to_id = {lang: idx for idx, lang in enumerate(languages)} + self.id_to_language = {idx: lang for lang, idx in self.language_to_id.items()} + + # Load FLEURS dataset for multiple languages + self.load_datasets() + + def find_dataset_path(self, lang: str) -> Optional[str]: + """Find the dataset path in the local datasets directory structure""" + # Dataset root directory + dataset_root = os.path.join("fleurs", "datasets", "google__fluers", lang) + + if not os.path.exists(dataset_root): + print(f"Dataset directory not found for language {lang}") + return None + + # Look for version directories (e.g., 2.0.0) + versions = [d for d in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, d))] + if not versions: + print(f"No dataset versions found for language {lang}") + return None + + # Use the latest version + latest_version = sorted(versions)[-1] + version_path = os.path.join(dataset_root, latest_version) + + # Look for hash directory + try: + hash_dirs = [d for d in os.listdir(version_path) if os.path.isdir(os.path.join(version_path, d))] + if not hash_dirs: + print(f"No hash directory found for language {lang}") + return None + + # Use the first hash directory found + hash_dir = hash_dirs[0] + dataset_path = os.path.join(version_path, hash_dir) + + # Verify that necessary files exist + required_files = [ + "dataset_info.json", + "fleurs-validation.arrow", + "fleurs-test.arrow" + ] + + # Check for at least one training shard + train_shard_found = False + for file in os.listdir(dataset_path): + if file.startswith("fleurs-train-") and file.endswith(".arrow"): + train_shard_found = True + break + + if not train_shard_found: + print(f"No training shards found for language {lang}") + return None + + for file in required_files: + if not os.path.exists(os.path.join(dataset_path, file)): + print(f"Missing required file {file} for language {lang}") + return None + + print(f"Found dataset for {lang} at {dataset_path}") + return dataset_path + + except Exception as e: + print(f"Error accessing hash directory for language {lang}: {str(e)}") + return None + + def get_num_shards(self, dataset_path: str) -> int: + """Determine the number of training shards in the dataset""" + shard_files = [f for f in os.listdir(dataset_path) if f.startswith("fleurs-train-") and f.endswith(".arrow")] + if not shard_files: + return 0 + + # Extract the total number of shards from the filename pattern + # Example: "fleurs-train-00000-of-00004.arrow" -> 4 + sample_file = shard_files[0] + try: + total_shards = int(sample_file.split("-of-")[1].split(".")[0]) + return total_shards + except: + return len(shard_files) + + def load_local_dataset(self, lang: str) -> Optional[Dataset]: + """Load dataset from local directory""" + try: + # Find the dataset path + dataset_path = self.find_dataset_path(lang) + if dataset_path is None: + return None + + # Load the dataset based on the data type + if self.data_type == "train": + # Determine number of shards + num_shards = self.get_num_shards(dataset_path) + if num_shards == 0: + print(f"No training shards found for language 
{lang}") + return None + + # Load all available shards + shards = [] + for i in range(num_shards): + shard_path = os.path.join(dataset_path, f"fleurs-train-{i:05d}-of-{num_shards:05d}.arrow") + if os.path.exists(shard_path): + try: + # Load dataset shard + shard_dataset = Dataset.from_file(shard_path) + shards.append(shard_dataset) + except Exception as e: + print(f"Error loading shard {i} for language {lang}: {str(e)}") + else: + print(f"Warning: Missing shard {i} for language {lang}") + + if not shards: + print(f"No training shards found for language {lang}") + return None + + # Concatenate all shards + try: + dataset = concatenate_datasets(shards) + return dataset + except Exception as e: + print(f"Error concatenating shards for language {lang}: {str(e)}") + return None + else: + # For validation and test, we have single files + file_path = os.path.join(dataset_path, f"fleurs-{self.data_type}.arrow") + if not os.path.exists(file_path): + print(f"Dataset file not found: {file_path}") + return None + try: + dataset = Dataset.from_file(file_path) + return dataset + except Exception as e: + print(f"Error loading {self.data_type} dataset for language {lang}: {str(e)}") + return None + + except Exception as e: + print(f"Error loading dataset for language {lang}: {str(e)}") + return None + + def load_datasets(self): + """Load datasets for all specified languages""" + for lang in tqdm(self.languages, desc="Loading languages"): + try: + # Load local dataset for the language + dataset = self.load_local_dataset(lang) + + if dataset is None: + continue + + # Apply sampling if specified + if self.max_samples_per_language: + dataset = dataset.select(range(min(len(dataset), self.max_samples_per_language))) + + self.dataset_cache[lang] = dataset + print(f"Loaded {len(dataset)} samples for {lang}") + except Exception as e: + print(f"Error loading dataset for {lang}: {str(e)}") + + def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray: + """Process audio data to extract features using NumPy-based extraction""" + try: + # Import librosa here to ensure it's available + import librosa + + if sampling_rate != self.config.speech_config['sample_rate']: + # Resample if necessary + audio_data = librosa.resample( + y=audio_data, + orig_sr=sampling_rate, + target_sr=self.config.speech_config['sample_rate'] + ) + + # Trim silence + audio_data, _ = librosa.effects.trim(audio_data, top_db=30) + + # Extract features using the speech featurizer + features = self.speech_featurizer.extract(audio_data) + + # Ensure the features are in the correct range + if np.isnan(features).any() or np.isinf(features).any(): + print("Warning: NaN or Inf values in features, replacing with zeros") + features = np.nan_to_num(features, 0) + + return features + + except Exception as e: + print(f"Error in prepare_audio: {str(e)}") + # Return empty features with correct shape as fallback + return np.zeros((1, self.config.speech_config['num_feature_bins'])) + + def get_batch_generator(self, batch_size: int): + """Generate batches of data""" + while True: + for lang in self.languages: + if lang not in self.dataset_cache: + continue + + dataset = self.dataset_cache[lang] + indices = list(range(len(dataset))) + + if self.config.dataset_config.get('shuffle', True): + np.random.shuffle(indices) + + for i in range(0, len(indices), batch_size): + batch_indices = indices[i:i + batch_size] + + features_list = [] + labels = [] + + # Process items one at a time + for idx in batch_indices: + try: + # Get item and ensure it's a 
dictionary + item = dataset[idx] + if not isinstance(item, dict): + item = dict(item) + + # Process audio data + audio_data = item['audio'] + if isinstance(audio_data, dict): + audio_array = audio_data.get('array') + sampling_rate = audio_data.get('sampling_rate') + else: + print(f"Unexpected audio data format for item {idx} in {lang}") + continue + + if audio_array is None or sampling_rate is None: + print(f"Missing audio data or sampling rate for item {idx} in {lang}") + continue + + # Convert audio array if needed + if isinstance(audio_array, (list, tuple)): + audio_array = np.array(audio_array) + + # Extract features + features = self.prepare_audio(audio_array, sampling_rate) + features_list.append(features) + + # Get language label + labels.append(self.language_to_id[lang]) + except Exception as e: + print(f"Error processing item {idx} in {lang}: {str(e)}") + continue + + if not features_list: + continue + + # Pad features to same length + max_len = max(feat.shape[0] for feat in features_list) + padded_features = np.zeros((len(features_list), max_len, features_list[0].shape[1])) + + for j, feat in enumerate(features_list): + padded_features[j, :feat.shape[0], :] = feat + + yield { + 'features': padded_features, + 'input_lengths': np.array([len(feat) for feat in features_list]), + 'labels': np.array(labels) + } + + def save_language_mapping(self, file_path: str): + """Save language to ID mapping""" + with open(file_path, 'w') as f: + json.dump({ + 'language_to_id': self.language_to_id, + 'id_to_language': self.id_to_language + }, f, indent=2) + + @classmethod + def load_language_mapping(cls, file_path: str) -> Dict: + """Load language to ID mapping""" + with open(file_path, 'r') as f: + return json.load(f) + + def get_num_languages(self) -> int: + """Get total number of languages""" + return len(self.languages) + + def get_language_list(self) -> List[str]: + """Get list of all languages""" + return self.languages.copy() \ No newline at end of file diff --git a/predict_by_pb.py b/predict_by_pb.py index 9be4dc6..1495b27 100644 --- a/predict_by_pb.py +++ b/predict_by_pb.py @@ -1,30 +1,86 @@ -from signal import signal +import os +# Force CPU only +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +# Must import tensorflow after setting environment variables import tensorflow as tf -gpus = tf.config.list_physical_devices('GPU') -tf.config.set_visible_devices(gpus[0:1], 'GPU') +print("TensorFlow version:", tf.__version__) +print("Using CPU only") + from vocab.vocab import Vocab import librosa import numpy as np import sys -import os from tqdm import tqdm from sklearn.metrics import accuracy_score +def load_model(model_path): + try: + # Load model in CPU mode + with tf.device('/CPU:0'): + return tf.saved_model.load(model_path) + except Exception as e: + print(f"Error loading model: {str(e)}") + return None -vocab = Vocab("vocab/vocab.txt") -model = tf.saved_model.load('saved_models/lang14/pb/2/') - - -def predict_wav(wav_path): - signal, _ = librosa.load(wav_path, sr=16000) - output, prob = model.predict_pb(signal) - language = vocab.token_list[output.numpy()] - print(language, prob.numpy()*100) - - return output.numpy(), prob.numpy() - +def predict_wav(wav_path, model, vocab): + try: + # Load and preprocess audio + signal, _ = librosa.load(wav_path, sr=16000) + + # Convert to tensor and ensure CPU operation + with tf.device('/CPU:0'): + signal = tf.convert_to_tensor(signal, dtype=tf.float32) + + # Make prediction + if hasattr(model, 'predict_pb'): + output 
= model.predict_pb(signal) + else: + output = model(signal) + + if isinstance(output, dict): + pred = output.get("output_0", None) + prob = output.get("output_1", None) + elif isinstance(output, (list, tuple)) and len(output) == 2: + pred, prob = output + else: + print("Unexpected model output format") + return None, None + + # Get prediction + pred_idx = tf.argmax(pred).numpy() + probability = tf.reduce_max(tf.nn.softmax(prob)).numpy() + + language = vocab.token_list[pred_idx] + print(f"Detected language: {language} (confidence: {probability*100:.2f}%)") + + return pred_idx, probability + + except Exception as e: + print(f"Error during prediction: {str(e)}") + return None, None if __name__ == '__main__': - wav_path = sys.argv[1] - predict_wav(wav_path) - + try: + # Initialize vocabulary + vocab = Vocab("vocab/vocab.txt") + + # Load model + print("Loading model...") + model = load_model('saved_models/lang14/pb/2/') + + if model is None: + print("Failed to load model. Exiting.") + sys.exit(1) + + # Make prediction + print("Making prediction...") + predict_wav("test_audios/french.wav", model, vocab) + + except Exception as e: + print(f"Error: {str(e)}") + print("\nTroubleshooting tips:") + print("1. Make sure the model file exists in saved_models/lang14/pb/2/") + print("2. Make sure the vocabulary file exists in vocab/vocab.txt") + print("3. Make sure the audio file exists in test_audios/french.wav") diff --git a/requirements.txt b/requirements.txt index d3be841..22f8fad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,12 @@ -tensorflow==2.4.1 -tensorflow-gpu==2.4.1 -tensorflow-addons==0.15.0 -matplotlib==3.5.0 -numpy==1.19.5 -scikit-learn==1.0.1 -librosa==0.8.1 -SoundFile==0.10.3.post1 -PyYAML==6.0 \ No newline at end of file +tensorflow-cpu==2.11.0 +numpy==1.23.5 +librosa==0.10.1 +soundfile==0.12.1 +matplotlib==3.7.1 +scikit-learn==1.2.2 +PyYAML==6.0 +tqdm>=4.65.0 +pandas>=2.0.0 +datasets>=2.12.0 +huggingface-hub>=0.16.4 +transformers>=4.30.0 diff --git a/saved_models/.DS_Store b/saved_models/.DS_Store new file mode 100644 index 0000000..aa05c57 Binary files /dev/null and b/saved_models/.DS_Store differ diff --git a/saved_models/lang14/.DS_Store b/saved_models/lang14/.DS_Store new file mode 100644 index 0000000..12f0c23 Binary files /dev/null and b/saved_models/lang14/.DS_Store differ diff --git a/saved_models/lang14/pb/.DS_Store b/saved_models/lang14/pb/.DS_Store new file mode 100644 index 0000000..60b820d Binary files /dev/null and b/saved_models/lang14/pb/.DS_Store differ diff --git a/saved_models/lang14/pb/2/.DS_Store b/saved_models/lang14/pb/2/.DS_Store new file mode 100644 index 0000000..19ad97c Binary files /dev/null and b/saved_models/lang14/pb/2/.DS_Store differ diff --git a/test_tf.py b/test_tf.py new file mode 100644 index 0000000..f560870 --- /dev/null +++ b/test_tf.py @@ -0,0 +1,13 @@ +import tensorflow as tf + +cifar = tf.keras.datasets.cifar100 +(x_train, y_train), (x_test, y_test) = cifar.load_data() +model = tf.keras.applications.ResNet50( + include_top=True, + weights=None, + input_shape=(32, 32, 3), + classes=100,) + +loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) +model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"]) +model.fit(x_train, y_train, epochs=5, batch_size=64) diff --git a/train.py b/train.py index f931cec..4d7d11a 100644 --- a/train.py +++ b/train.py @@ -3,8 +3,16 @@ import argparse import tensorflow as tf -gpus = tf.config.list_physical_devices('GPU') -# tf.config.set_visible_devices(gpus[0:1], 'GPU') 
+ +# Modern GPU memory growth configuration +physical_devices = tf.config.list_physical_devices('GPU') +if physical_devices: + try: + for device in physical_devices: + tf.config.experimental.set_memory_growth(device, True) + except RuntimeError as e: + print(e) + import datetime import time import os @@ -18,258 +26,213 @@ from dataset import create_dataset import tensorflow_addons as tfa from sklearn.metrics import f1_score, recall_score, precision_score -mirrored_strategy = tf.distribute.MirroredStrategy() +# Use MirroredStrategy for multi-GPU training +strategy = tf.distribute.MirroredStrategy() def train(config_file): config = Config(config_file) current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") dir_log_root = "./saved_weights/" if not os.path.exists(dir_log_root): - os.mkdir(dir_log_root) - dir_current = dir_log_root + current_time - if not os.path.isdir(dir_log_root): - os.mkdir(dir_log_root) - if not os.path.isdir(dir_current): - os.mkdir(dir_current) - copyfile(config_file, dir_current + '/config.yml') - log_file = open(dir_current + '/log.txt', 'w') - copyfile(config.dataset_config['vocabulary'], dir_current + '/vocab.txt') + os.makedirs(dir_log_root) + dir_current = os.path.join(dir_log_root, current_time) + if not os.path.exists(dir_current): + os.makedirs(dir_current) + os.makedirs(os.path.join(dir_current, 'best')) + os.makedirs(os.path.join(dir_current, 'last')) + + copyfile(config_file, os.path.join(dir_current, 'config.yml')) + log_file = open(os.path.join(dir_current, 'log.txt'), 'w') + copyfile(config.dataset_config['vocabulary'], os.path.join(dir_current, 'vocab.txt')) - config.print() log_file.write(config.toString()) - # vocab_file.write(config.toString()) log_file.flush() vocab = Vocab(config.dataset_config['vocabulary']) batch_size = config.running_config['batch_size'] - global_batch_size = batch_size * mirrored_strategy.num_replicas_in_sync + global_batch_size = batch_size * strategy.num_replicas_in_sync speech_featurizer = NumpySpeechFeaturizer(config.speech_config) - model = Model(**config.model_config, vocab_size=len(vocab.token_list)) - if config.running_config['load_weights'] is not None: - model.load_weights(config.running_config['load_weights']) - model.add_featurizers(speech_featurizer) - model.init_build([None, config.speech_config['num_feature_bins']]) - model.summary() - train_dataset = create_dataset(batch_size=global_batch_size, - load_type=config.dataset_config['load_type'], - data_type=config.dataset_config['train'], - speech_featurizer=speech_featurizer, - config = config, - vocab = vocab) - eval_dataset = create_dataset(batch_size=global_batch_size, - load_type=config.dataset_config['load_type'], - data_type=config.dataset_config['dev'], - speech_featurizer=speech_featurizer, - config = config, - vocab = vocab) - test_dataset = create_dataset(batch_size=global_batch_size, - load_type=config.dataset_config['load_type'], - data_type=config.dataset_config['test'], - speech_featurizer=speech_featurizer, - config = config, - vocab = vocab) - train_dist_batch = mirrored_strategy.experimental_distribute_dataset(train_dataset) - dev_dist_batch = mirrored_strategy.experimental_distribute_dataset(eval_dataset) - test_dist_batch = mirrored_strategy.experimental_distribute_dataset(test_dataset) - dev_loss = tf.keras.metrics.Mean(name='dev_loss') + with strategy.scope(): + model = Model(**config.model_config, vocab_size=len(vocab.token_list)) + if config.running_config['load_weights'] is not None: + 
model.load_weights(config.running_config['load_weights']) + model.add_featurizers(speech_featurizer) + model.init_build([None, config.speech_config['num_feature_bins']]) + model.summary() + + # Use modern optimizers with learning rate schedules + learning_rate = tf.keras.optimizers.schedules.CosineDecay( + config.optimizer_config['max_lr'], + decay_steps=config.running_config['num_epochs'] * config.running_config['train_steps'] + ) + optimizer = tf.keras.optimizers.AdamW( + learning_rate=learning_rate, + weight_decay=0.01 + ) + + # Modern loss functions + loss_fn = tfa.losses.SigmoidFocalCrossEntropy( + from_logits=True, + alpha=0.25, + gamma=2.0, + reduction=tf.keras.losses.Reduction.NONE + ) + loss_fn_smooth = tf.keras.losses.CategoricalCrossentropy( + from_logits=True, + label_smoothing=0.1, + reduction=tf.keras.losses.Reduction.NONE + ) + + # Create datasets + train_dataset = create_dataset( + batch_size=global_batch_size, + load_type=config.dataset_config['load_type'], + data_type=config.dataset_config['train'], + speech_featurizer=speech_featurizer, + config=config, + vocab=vocab + ) + eval_dataset = create_dataset( + batch_size=global_batch_size, + load_type=config.dataset_config['load_type'], + data_type=config.dataset_config['dev'], + speech_featurizer=speech_featurizer, + config=config, + vocab=vocab + ) + test_dataset = create_dataset( + batch_size=global_batch_size, + load_type=config.dataset_config['load_type'], + data_type=config.dataset_config['test'], + speech_featurizer=speech_featurizer, + config=config, + vocab=vocab + ) + + # Distribute datasets + train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset) + eval_dist_dataset = strategy.experimental_distribute_dataset(eval_dataset) + test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset) + + # Metrics train_loss = tf.keras.metrics.Mean(name='train_loss') - dev_accuracy = tf.keras.metrics.Mean(name='train_accuracy') - init_steps = config.optimizer_config['init_steps'] - step = tf.Variable(init_steps) - - optimizer = tf.keras.optimizers.Adam(lr=config.optimizer_config['max_lr']) - ckpt = tf.train.Checkpoint(step=step, optimizer=optimizer, model=model) - ckpt_manager = tf.train.CheckpointManager(ckpt, dir_current + '/ckpt', max_to_keep=5) - loss_object = tfa.losses.SigmoidFocalCrossEntropy( - from_logits = True, - alpha = 0.25, - gamma = 0, - reduction = tf.keras.losses.Reduction.NONE) - loss_object_label_smooth = tf.keras.losses.CategoricalCrossentropy( - from_logits=True, label_smoothing=0.1, reduction=tf.keras.losses.Reduction.NONE) - - def compute_loss(real, pred, smooth=False): - if smooth: - loss_ = loss_object_label_smooth(tf.one_hot(real, len(vocab.token_list)), pred) - else: - real = tf.one_hot(real, len(vocab.token_list)) - loss_ = loss_object(real, pred) - return tf.nn.compute_average_loss(loss_, global_batch_size=global_batch_size) + val_loss = tf.keras.metrics.Mean(name='val_loss') + train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy') + val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy') - def accuracy_function(real, pred): - pred = tf.cast(pred, dtype=tf.int32) - accuracies = tf.equal(real, pred) - - mask = tf.math.logical_not(tf.math.equal(real, 0)) - accuracies = tf.math.logical_and(mask, accuracies) - - accuracies = tf.cast(accuracies, dtype=tf.float32) - mask = tf.cast(mask, dtype=tf.float32) - - return tf.reduce_sum(accuracies)/tf.reduce_sum(mask) + def compute_loss(labels, predictions, smooth=False): + 
per_example_loss = loss_fn_smooth(tf.one_hot(labels, len(vocab.token_list)), predictions) if smooth else \ + loss_fn(tf.one_hot(labels, len(vocab.token_list)), predictions) + return tf.nn.compute_average_loss(per_example_loss, global_batch_size=global_batch_size) @tf.function - def train_step(input, input_length, target): + def train_step(inputs): + x, x_len, y = inputs + with tf.GradientTape() as tape: - predictions = model([input, input_length], training=True) - loss = compute_loss(target, predictions, smooth=True) - grads = tape.gradient(loss, model.trainable_variables) - optimizer.apply_gradients(zip(grads, model.trainable_variables)) + predictions = model([x, x_len], training=True) + loss = compute_loss(y, predictions, smooth=True) + + gradients = tape.gradient(loss, model.trainable_variables) + optimizer.apply_gradients(zip(gradients, model.trainable_variables)) + + train_loss.update_state(loss) + train_accuracy.update_state(y, predictions) return loss @tf.function - def dev_step(input, input_length, target): - predictions = model([input, input_length], training=False) - t_loss = compute_loss(target, predictions, smooth=True) + def test_step(inputs): + x, x_len, y = inputs + predictions = model([x, x_len], training=False) + loss = compute_loss(y, predictions, smooth=True) - return t_loss, predictions + val_loss.update_state(loss) + val_accuracy.update_state(y, predictions) + return predictions, y @tf.function - def test_step(input, input_length, target): - predictions = model([input, input_length], training=False) - return predictions, target - - @tf.function(experimental_relax_shapes=True) - def distributed_train_step(x, x_len, y): - per_replica_losses = mirrored_strategy.run(train_step, args=(x, x_len, y)) - mean_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) - return mean_loss - - @tf.function(experimental_relax_shapes=True) - def distributed_dev_step(x, x_len, y): - per_replica_losses, per_replica_preds = mirrored_strategy.run(dev_step, args=(x, x_len, y)) - mean_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) - return mean_loss, per_replica_preds - - - @tf.function(experimental_relax_shapes=True) - def distributed_test_step(x, x_len, y): - return mirrored_strategy.run(test_step, args=(x, x_len, y)) + def distributed_train_step(inputs): + per_replica_losses = strategy.run(train_step, args=(inputs,)) + return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None) - plot_train_loss = [] - plot_dev_loss = [] - plot_acc, plot_precision = [], [] - best_acc= 0 - train_iter = iter(train_dist_batch) - dev_iter = iter(dev_dist_batch) - test_iter = iter(test_dist_batch) + @tf.function + def distributed_test_step(inputs): + return strategy.run(test_step, args=(inputs,)) + # Training loop + best_accuracy = 0 for epoch in range(1, config.running_config['num_epochs'] + 1): - if config.dataset_config['load_type']=='txt': - train_iter = iter(train_dist_batch) - dev_iter = iter(dev_dist_batch) - test_iter = iter(test_dist_batch) - start = time.time() - # training loop - train_loss = 0.0 - dev_loss = 0.0 - for train_batches in range(config.running_config['train_steps']): - inp, inp_len, target = next(train_iter) - train_loss += distributed_train_step(inp, inp_len, target) - template = '\rEpoch {} Step {} Loss {:.4f}' - print(colored(template.format( - epoch, train_batches + 1, train_loss / (train_batches + 1), - ), 'green'), end='', flush=True) - step.assign_add(1) - - # validation loop - 
pred_all = tf.zeros([1], dtype=tf.int32) - true_all = tf.zeros([1], dtype=tf.int32) - for dev_batches in range(config.running_config['dev_steps']): - inp, inp_len, target = next(dev_iter) - loss, predicted_result = distributed_dev_step(inp, inp_len, target) - dev_loss += loss - if mirrored_strategy.num_replicas_in_sync == 1: - prediction = tf.nn.softmax(predicted_result) - y_pred = tf.argmax(prediction, axis=-1) - y_pred = tf.cast(y_pred, dtype=tf.int32) - pred_all = tf.concat([pred_all, y_pred], axis=0) - true_all = tf.concat([true_all, target], axis=0) - else: - for i in range(mirrored_strategy.num_replicas_in_sync): - predicted_result_per_replica = predicted_result.values[i] - y_true = target.values[i] - y_pred = tf.argmax(predicted_result_per_replica, axis=-1) - y_pred = tf.cast(y_pred, dtype=tf.int32) - pred_all = tf.concat([pred_all, y_pred], axis=0) - true_all = tf.concat([true_all, y_true], axis=0) - dev_accuracy = accuracy_function(true_all, pred_all) - - pred_all = tf.zeros([1], dtype=tf.int32) - true_all = tf.zeros([1], dtype=tf.int32) - for test_batches in range(config.running_config['test_steps']): - inp, inp_len, target = next(test_iter) - predicted_result, target_result = distributed_test_step(inp, inp_len, target) - if mirrored_strategy.num_replicas_in_sync == 1: - prediction = tf.nn.softmax(predicted_result) - y_pred =tf.argmax(prediction, axis=-1) - y_pred = tf.cast(y_pred, dtype=tf.int32) - pred_all = tf.concat([pred_all, y_pred], axis=0) - true_all = tf.concat([true_all, target], axis=0) - else: - for replica in range(mirrored_strategy.num_replicas_in_sync): - predicted_result_per_replica = predicted_result.values[i] - y_true = target.values[i] - y_pred = tf.argmax(predicted_result_per_replica, axis=-1) - y_pred = tf.cast(y_pred, dtype=tf.int32) - pred_all = tf.concat([pred_all, y_pred], axis=0) - true_all = tf.concat([true_all, y_true], axis=0) + start_time = time.time() - test_acc = accuracy_function(real=true_all, pred=pred_all) - - test_f1 = f1_score(y_true=true_all, y_pred=pred_all, average='macro') - precision = precision_score(y_true=true_all, y_pred=pred_all, average='macro', zero_division=1) - recall = recall_score(y_true=true_all, y_pred=pred_all, average='macro') - if precision > best_acc: - best_acc = precision - model.save_weights(dir_current + '/best/' + 'model') - model.save_weights(dir_current + '/last/' + 'model') - template = ("\rEpoch {}, Loss: {:.4f}, Val Loss: {:.4f}, " - "Val Acc: {:.4f}, test ACC: {:.4f},F1: {:.4f}, precision: {:.4f}, recall: {:.4f}, Time Cost: {:.2f} sec") - text = template.format(epoch, train_loss / config.running_config['train_steps'], - dev_loss/ config.running_config['dev_steps'], dev_accuracy *100, - test_acc*100, test_f1*100, precision*100, recall*100, time.time() - start) - print(colored(text, 'cyan')) - log_file.write(text) + # Reset metrics + train_loss.reset_states() + val_loss.reset_states() + train_accuracy.reset_states() + val_accuracy.reset_states() + + # Training + for step, inputs in enumerate(train_dist_dataset): + loss = distributed_train_step(inputs) + if step % 10 == 0: + template = 'Epoch {}, Step {}, Loss: {:.4f}, Accuracy: {:.4f}' + print(colored(template.format( + epoch, step + 1, + train_loss.result(), + train_accuracy.result() + ), 'green')) + + # Validation + all_predictions = [] + all_labels = [] + for inputs in eval_dist_dataset: + predictions, labels = distributed_test_step(inputs) + all_predictions.extend(tf.argmax(predictions, axis=-1).numpy()) + all_labels.extend(labels.numpy()) + + # Calculate 
metrics + val_f1 = f1_score(y_true=all_labels, y_pred=all_predictions, average='macro') + val_precision = precision_score(y_true=all_labels, y_pred=all_predictions, average='macro', zero_division=1) + val_recall = recall_score(y_true=all_labels, y_pred=all_predictions, average='macro') + + # Save best model + if val_precision > best_accuracy: + best_accuracy = val_precision + model.save_weights(os.path.join(dir_current, 'best', 'model')) + model.save_weights(os.path.join(dir_current, 'last', 'model')) + + # Log results + template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}, Val Loss: {:.4f}, Val Accuracy: {:.4f}, F1: {:.4f}, Precision: {:.4f}, Recall: {:.4f}, Time: {:.2f}s' + print(template.format( + epoch, + train_loss.result(), + train_accuracy.result(), + val_loss.result(), + val_accuracy.result(), + val_f1, + val_precision, + val_recall, + time.time() - start_time + )) + log_file.write(template.format( + epoch, + train_loss.result(), + train_accuracy.result(), + val_loss.result(), + val_accuracy.result(), + val_f1, + val_precision, + val_recall, + time.time() - start_time + ) + '\n') log_file.flush() - plot_train_loss.append(train_loss / config.running_config['train_steps']) - plot_dev_loss.append(dev_loss / config.running_config['dev_steps']) - plot_acc.append(test_acc) - plot_precision.append(precision) - ckpt_manager.save() - - plt.plot(plot_train_loss, '-r', label='train_loss') - plt.title('Train Loss') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/loss.png') - #plot dev - plt.clf() - plt.plot(plot_dev_loss, '-g', label='dev_loss') - plt.title('dev Loss') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/dev_loss.png') - - # plot acc curve - plt.clf() - plt.plot(plot_acc, 'b-', label='acc') - plt.title('Accuracy') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/acc.png') - # plot f1 curve - plt.clf() - plt.plot(plot_precision, 'y-', label='f1-score') - plt.title('F1') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/f1-score.png') - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Spoken_language_identification Model training") parser.add_argument("--config_file", type=str, default='./configs/config.yml', help="Config File Path") args = parser.parse_args() kwargs = vars(args) - with mirrored_strategy.scope(): + with strategy.scope(): train(**kwargs) \ No newline at end of file diff --git a/train_multilingual.py b/train_multilingual.py new file mode 100644 index 0000000..87cfa73 --- /dev/null +++ b/train_multilingual.py @@ -0,0 +1,164 @@ +import os +import json +import tensorflow as tf +from tqdm import tqdm +from datetime import datetime +from multilingual_dataset import MultilingualDataset +from featurizers.speech_featurizers import NumpySpeechFeaturizer +from configs.config import Config +from vocab.vocab import Vocab + +def setup_mixed_precision(): + """Setup mixed precision for better performance on ARM32""" + policy = tf.keras.mixed_precision.Policy('mixed_float16') + tf.keras.mixed_precision.set_global_policy(policy) + +class ExpandDimsLayer(tf.keras.layers.Layer): + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def call(self, inputs): + return tf.expand_dims(inputs, axis=self.axis) + +class SqueezeLayer(tf.keras.layers.Layer): + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def call(self, inputs): + return tf.squeeze(inputs, axis=self.axis) + +def create_model(config, num_languages): + """Create the model with support for multiple languages""" 
+ inputs = tf.keras.Input(shape=(None, config.speech_config['num_feature_bins'])) + x = inputs + + # CNN layers + for filters, kernel in zip(config.model_config['filters'], config.model_config['kernel_size']): + x = ExpandDimsLayer(axis=-1)(x) + x = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=kernel, + padding='same', + activation='relu' + )(x) + x = tf.keras.layers.BatchNormalization()(x) + x = SqueezeLayer(axis=-1)(x) + + # BiLSTM layers + x = tf.keras.layers.Bidirectional( + tf.keras.layers.LSTM( + config.model_config['rnn_cell'], + return_sequences=True + ) + )(x) + + # Global pooling + x = tf.keras.layers.GlobalAveragePooling1D()(x) + + # Output layer + outputs = tf.keras.layers.Dense(num_languages, activation='softmax')(x) + + model = tf.keras.Model(inputs=inputs, outputs=outputs) + return model + +def main(): + # Load configuration + config = Config("configs/config.yml") + + # Setup mixed precision if enabled + if config.optimizer_config.get('use_mixed_precision', False): + setup_mixed_precision() + + # Load language configuration + with open(config.dataset_config['languages_file'], 'r') as f: + languages_config = json.load(f) + languages = languages_config['supported_languages'] + + # Initialize components + vocab = Vocab(config.dataset_config['vocabulary']) + speech_featurizer = NumpySpeechFeaturizer(config.speech_config) + + # Create datasets + train_dataset = MultilingualDataset( + config=config, + languages=languages, + vocab=vocab, + speech_featurizer=speech_featurizer, + data_type='train', + max_samples_per_language=config.dataset_config['max_samples_per_language'] + ) + + val_dataset = MultilingualDataset( + config=config, + languages=languages, + vocab=vocab, + speech_featurizer=speech_featurizer, + data_type='validation', + max_samples_per_language=config.dataset_config['max_samples_per_language'] // 10 + ) + + # Create model + model = create_model(config, len(languages)) + + # Compile model + optimizer = tf.keras.optimizers.Adam( + learning_rate=config.optimizer_config['max_lr'], + beta_1=config.optimizer_config['beta1'], + beta_2=config.optimizer_config['beta2'], + epsilon=config.optimizer_config['epsilon'] + ) + + model.compile( + optimizer=optimizer, + loss='sparse_categorical_crossentropy', + metrics=['accuracy'] + ) + + # Setup callbacks + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + save_dir = os.path.join('saved_weights', 'multilingual', timestamp) + os.makedirs(save_dir, exist_ok=True) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(save_dir, 'epoch_{epoch:02d}'), + save_weights_only=True, + save_freq='epoch' + ), + tf.keras.callbacks.TensorBoard( + log_dir=os.path.join('logs', timestamp), + update_freq='batch' + ), + tf.keras.callbacks.EarlyStopping( + monitor='val_loss', + patience=5, + restore_best_weights=True + ) + ] + + # Train model + train_generator = train_dataset.get_batch_generator(config.running_config['batch_size']) + val_generator = val_dataset.get_batch_generator(config.running_config['batch_size']) + + steps_per_epoch = config.running_config['train_steps'] + validation_steps = config.running_config['dev_steps'] + + model.fit( + train_generator, + steps_per_epoch=steps_per_epoch, + validation_data=val_generator, + validation_steps=validation_steps, + epochs=config.running_config['num_epochs'], + callbacks=callbacks + ) + + # Save final model + model.save_weights(os.path.join(save_dir, 'final')) + + # Save language mapping + train_dataset.save_language_mapping(os.path.join(save_dir, 
'language_mapping.json')) + +if __name__ == '__main__': + main() \ No newline at end of file
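
Note (not part of the diff above): MultilingualDataset.get_batch_generator yields dict batches with 'features', 'input_lengths', and 'labels', while tf.keras.Model.fit consumes (inputs, targets) pairs, so some adapter is needed when wiring the generator into model.fit in train_multilingual.py. The sketch below is one possible bridge under that assumption; the helper name make_fit_dataset is hypothetical, and input_lengths is dropped because the functional model built by create_model() only consumes the padded feature tensor.

# Hypothetical adapter: wraps the dict-yielding batch generator as a
# tf.data.Dataset of (features, labels) tuples for tf.keras.Model.fit.
import tensorflow as tf

def make_fit_dataset(batch_generator, num_feature_bins):
    """Convert dict batches into (padded_features, language_ids) pairs."""
    output_signature = (
        tf.TensorSpec(shape=(None, None, num_feature_bins), dtype=tf.float32),  # padded features
        tf.TensorSpec(shape=(None,), dtype=tf.int32),                            # language ids
    )

    def gen():
        for batch in batch_generator:
            yield (batch['features'].astype('float32'),
                   batch['labels'].astype('int32'))

    # Batches arrive already padded and batched, so no extra .batch() call is needed.
    return tf.data.Dataset.from_generator(gen, output_signature=output_signature)

# Example usage (steps_per_epoch is required because the generator loops forever):
# train_ds = make_fit_dataset(train_dataset.get_batch_generator(batch_size),
#                             config.speech_config['num_feature_bins'])
# model.fit(train_ds, steps_per_epoch=steps_per_epoch, epochs=num_epochs)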