diff --git a/modal_infer.py b/modal_infer.py index c61e9d3..daf0a25 100644 --- a/modal_infer.py +++ b/modal_infer.py @@ -1,15 +1,14 @@ from __future__ import annotations import argparse -import contextlib import io import logging -import subprocess import sys -from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime from pathlib import Path, PurePosixPath +import subprocess +from typing import Any, Dict, List, Optional, Sequence, Tuple from uuid import uuid4 @@ -45,7 +44,7 @@ def ensure_utf8_stdio() -> None: try: import modal -except ImportError: # pragma: no cover +except ImportError as exc: # pragma: no cover print("未检测到 modal 包,请先运行 `python -m pip install modal questionary`。") raise @@ -86,6 +85,7 @@ def ensure_utf8_stdio() -> None: "H200", "B200", ] +DEFAULT_CLS_CONCURRENCY = 4 def resolve_resource_path(filename: str) -> Path: @@ -97,8 +97,8 @@ def resolve_resource_path(filename: str) -> Path: class ModelProfile: key: str label: str - hf_repo: str | None - target_dir: str | None + hf_repo: Optional[str] + target_dir: Optional[str] description: str @@ -108,11 +108,12 @@ class UserSelection: gpu_choice: str input_path: Path model_profile: ModelProfile - custom_repo: str | None - custom_target_dir: str | None + custom_repo: Optional[str] + custom_target_dir: Optional[str] enable_batching: bool - batch_size: int | None + batch_size: Optional[int] max_batch_size: int + cls_concurrency: int timeout_minutes: int @@ -121,17 +122,18 @@ class UploadManifest: session_id: str source_type: str # file or directory local_source: Path - remote_inputs_rel: list[Path] + remote_inputs_rel: List[Path] remote_output_rel: Path local_output_dir: Path remote_logs_rel: Path - original_filename: str | None = None # 原始文件名(用于恢复空格) + original_filename: Optional[str] = None # 原始文件名(用于恢复空格) + original_rel_path: Optional[Path] = None # 原始文件的相对路径(用于重建目录结构) @dataclass class ScanResult: - audio_files: list[Path] - mp4_files: list[Path] + audio_files: List[Path] + mp4_files: List[Path] class NoAudioFilesError(Exception): @@ -164,7 +166,7 @@ def container_to_volume_path(container_path: str) -> str: return rel -MODEL_PRESETS: dict[str, ModelProfile] = { +MODEL_PRESETS: Dict[str, ModelProfile] = { "chickenrice": ModelProfile( key="chickenrice", label="海南鸡(日文转中文优化)", @@ -214,7 +216,9 @@ def setup_logger() -> Path: def ensure_questionary(): if questionary is None or Choice is None: - raise RuntimeError("需要 questionary,请运行 `python -m pip install questionary`。") + raise RuntimeError( + "需要 questionary,请运行 `python -m pip install questionary`。" + ) def ask_selection() -> UserSelection: @@ -229,7 +233,10 @@ def ask_selection() -> UserSelection: model_key = questionary.select( "选择模型:", - choices=[Choice(title=profile.label, value=key) for key, profile in MODEL_PRESETS.items()], + choices=[ + Choice(title=profile.label, value=key) + for key, profile in MODEL_PRESETS.items() + ], ).ask() if not model_key: raise KeyboardInterrupt @@ -241,7 +248,9 @@ def ask_selection() -> UserSelection: custom_repo = questionary.text("输入 HuggingFace repo(例如 user/repo)").ask() if not custom_repo: raise KeyboardInterrupt - custom_target_dir = questionary.text("输入 models 子目录名称(英文/数字)", default="custom-model").ask() + custom_target_dir = questionary.text( + "输入 models 子目录名称(英文/数字)", default="custom-model" + ).ask() if not custom_target_dir: raise KeyboardInterrupt @@ -252,20 +261,35 @@ def ask_selection() -> UserSelection: if not input_path.exists(): raise FileNotFoundError(f"路径不存在:{input_path}") - enable_batching = questionary.confirm("启用批处理以加速(需要更高显存)?", default=False).ask() + enable_batching = questionary.confirm( + "启用批处理以加速(需要更高显存)?", default=False + ).ask() if enable_batching is None: raise KeyboardInterrupt batch_size = None max_batch_size = 8 if enable_batching: - batch_size_str = questionary.text("指定批次大小(留空自动探测)", default="").ask() + batch_size_str = questionary.text( + "指定批次大小(留空自动探测)", default="" + ).ask() if batch_size_str: batch_size = int(batch_size_str) max_batch_size_str = questionary.text("最大自动批次大小", default="8").ask() max_batch_size = int(max_batch_size_str or "8") - timeout_minutes = int(questionary.text("任务超时时间(分钟)", default="60").ask() or "60") + timeout_minutes = int( + questionary.text("任务超时时间(分钟)", default="60").ask() or "60" + ) + + cls_concurrency = int( + questionary.text( + "目录模式 cls 并发容器数", default=str(DEFAULT_CLS_CONCURRENCY) + ).ask() + or str(DEFAULT_CLS_CONCURRENCY) + ) + if cls_concurrency < 1: + raise ValueError("并发容器数必须 >= 1") return UserSelection( run_mode="once", @@ -277,14 +301,15 @@ def ask_selection() -> UserSelection: enable_batching=bool(enable_batching), batch_size=batch_size, max_batch_size=max_batch_size, + cls_concurrency=cls_concurrency, timeout_minutes=timeout_minutes, ) def scan_audio_files(path: Path) -> ScanResult: """扫描目录,返回音频文件和需要转换的 mp4 文件""" - audio_files: list[Path] = [] - mp4_files: list[Path] = [] + audio_files: List[Path] = [] + mp4_files: List[Path] = [] for file in path.rglob("*"): if file.is_file(): suffix = file.suffix.lower() @@ -311,11 +336,15 @@ def validate_audio_path(path: Path) -> ScanResult: scan_result = scan_audio_files(path) if scan_result.mp4_files: logging.warning("=" * 60) - logging.warning("发现 %d 个 mp4 文件,这些文件将被跳过:", len(scan_result.mp4_files)) + logging.warning( + "发现 %d 个 mp4 文件,这些文件将被跳过:", len(scan_result.mp4_files) + ) for mp4_file in scan_result.mp4_files: logging.warning(" - %s", mp4_file) logging.warning("请使用 ffmpeg 转换为 mp3 后再处理,例如:") - logging.warning(' ffmpeg -i "input.mp4" -vn -acodec libmp3lame "output.mp3"') + logging.warning( + ' ffmpeg -i "input.mp4" -vn -acodec libmp3lame "output.mp3"' + ) logging.warning("=" * 60) if not scan_result.audio_files: raise NoAudioFilesError(f"输入的文件夹内没有音频文件:{path}") @@ -324,11 +353,21 @@ def validate_audio_path(path: Path) -> ScanResult: raise ValueError(f"路径 {path} 既不是文件也不是文件夹。") +def generate_safe_filename(original_name: str) -> str: + """生成安全的文件名,保留原始名称和时间戳""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + stem = Path(original_name).stem + suffix = Path(original_name).suffix.lower() + # 移除特殊字符,替换空格为下划线 + safe_stem = "".join(c if c.isalnum() or c in "._-" else "_" for c in stem) + return f"{safe_stem}_{timestamp}{suffix}" + + def upload_single_file( volume: modal.Volume, selection: UserSelection, audio_file: Path, - base_dir: Path | None = None, + base_dir: Optional[Path] = None, ) -> UploadManifest: """上传单个音频文件到 Modal Volume。 @@ -342,18 +381,24 @@ def upload_single_file( remote_session_rel = Path(SESSION_SUBDIR) / session_id remote_logs_rel = remote_session_rel / "logs" - # 使用固定文件名避免全角字符等问题 + # 生成安全文件名,保留原始名称和时间戳 original_filename = audio_file.name - safe_filename = "todo" + audio_file.suffix.lower() + safe_filename = generate_safe_filename(original_filename) with volume.batch_upload(force=True) as batch: remote_rel = remote_session_rel / safe_filename - logging.info("上传文件 -> %s", rel_to_volume_path(remote_rel)) + logging.info("[上传] %s -> %s", audio_file, rel_to_volume_path(remote_rel)) batch.put_file(str(audio_file), rel_to_volume_path(remote_rel)) # 如果指定了 base_dir(文件夹模式),输出到 base_dir;否则输出到文件所在目录 local_output_dir = base_dir if base_dir else audio_file.parent + # 计算原始文件的相对路径(用于重建目录结构) + if base_dir and audio_file.is_relative_to(base_dir): + original_rel_path: Path = audio_file.relative_to(base_dir) + else: + original_rel_path = Path(audio_file.name) + return UploadManifest( session_id=session_id, source_type="file", @@ -362,15 +407,22 @@ def upload_single_file( remote_output_rel=remote_session_rel, local_output_dir=local_output_dir, remote_logs_rel=remote_logs_rel, - original_filename=original_filename, # 始终记录原始文件名 + original_filename=original_filename, # 原始文件名 + original_rel_path=original_rel_path, # 原始相对路径(用于重建目录结构) ) -def build_job_payload(selection: UserSelection, manifest: UploadManifest) -> dict: +def build_job_payload(selection: UserSelection, manifest: UploadManifest) -> Dict: model_profile = selection.model_profile - hf_repo = selection.custom_repo if model_profile.key == "custom" else model_profile.hf_repo + hf_repo = ( + selection.custom_repo + if model_profile.key == "custom" + else model_profile.hf_repo + ) target_dir = ( - selection.custom_target_dir if model_profile.key == "custom" else model_profile.target_dir + selection.custom_target_dir + if model_profile.key == "custom" + else model_profile.target_dir ) or "custom-model" payload = { @@ -418,8 +470,8 @@ def run_remote_pipeline( volume: modal.Volume, selection: UserSelection, manifest: UploadManifest, - payload: dict, -) -> dict: + payload: Dict, +) -> Dict: logging.info("=== 开始构建 Modal 镜像 ===") image = build_modal_image() logging.info("✓ 镜像构建完成") @@ -434,7 +486,7 @@ def run_remote_pipeline( volumes={str(REMOTE_MOUNT): volume}, serialized=True, ) - def modal_pipeline(job_payload: dict) -> dict: + def modal_pipeline(job_payload: Dict) -> Dict: return _remote_pipeline(job_payload) logging.info("=== 开始远程执行 ===") @@ -451,8 +503,8 @@ def modal_pipeline(job_payload: dict) -> dict: def process_directory_files( volume: modal.Volume, selection: UserSelection, - audio_files: list[Path], -) -> tuple[int, int]: + audio_files: List[Path], +) -> Tuple[int, int]: """处理文件夹中的所有音频文件,容器复用。 Args: @@ -469,56 +521,76 @@ def process_directory_files( logging.info("使用 GPU:%s", selection.gpu_choice) logging.info("超时时间:%d 分钟", selection.timeout_minutes) logging.info("待处理文件数:%d", len(audio_files)) + logging.info("目录模式并发容器数:%d", selection.cls_concurrency) app = modal.App(APP_NAME) - @app.function( + @app.cls( image=image, gpu=selection.gpu_choice, timeout=selection.timeout_minutes * 60, volumes={str(REMOTE_MOUNT): volume}, serialized=True, - min_containers=1, # 保持容器预热,复用容器 + min_containers=1, + max_containers=selection.cls_concurrency, ) - def modal_pipeline(job_payload: dict) -> dict: - return _remote_pipeline(job_payload) + class ModalPipelineWorker: + @modal.method() + def run(self, job_payload: Dict) -> Dict: + return _remote_pipeline(job_payload) success_count = 0 fail_count = 0 base_dir = selection.input_path # 文件夹模式下,输出到源文件夹 + def has_output_files(audio_file: Path) -> bool: + """检查音频文件所在目录中是否已有对应的字幕文件""" + audio_stem = audio_file.stem + audio_dir = audio_file.parent + for suffix in SUB_SUFFIXES: + output_file = audio_dir / (audio_stem + suffix) + if output_file.exists(): + return True + return False + with app.run(): + worker = ModalPipelineWorker() + futures: List[Tuple[Path, UploadManifest, Any]] = [] + for i, audio_file in enumerate(audio_files, 1): + if has_output_files(audio_file): + logging.info("[跳过] 文件 %s 已存在字幕文件,跳过处理", audio_file) + success_count += 1 + continue + logging.info("=" * 60) - logging.info("处理文件 [%d/%d]: %s", i, len(audio_files), audio_file.name) + logging.info("提交文件 [%d/%d]: %s", i, len(audio_files), audio_file) logging.info("=" * 60) try: - # 1. 上传单个文件 manifest = upload_single_file(volume, selection, audio_file, base_dir) - - # 2. 构建 payload payload = build_job_payload(selection, manifest) + future = getattr(worker.run, "spawn")(payload) + futures.append((audio_file, manifest, future)) + except Exception as e: + logging.error("✗ 文件 %s 提交失败: %s", audio_file.name, e) + fail_count += 1 - # 3. 执行推理(复用容器) - logging.info("正在执行推理...") - result = modal_pipeline.remote(payload) - - # 4. 写入结果文件到本地 + for audio_file, manifest, future in futures: + try: + result = getattr(future, "get")() download_outputs(manifest, result) - logging.info("✓ 文件 %s 处理完成", audio_file.name) success_count += 1 except Exception as e: logging.error("✗ 文件 %s 处理失败: %s", audio_file.name, e) fail_count += 1 - continue # 继续处理下一个文件 return success_count, fail_count def download_outputs( manifest: UploadManifest, - result: dict, + result: Dict, ) -> None: """从远程结果中提取文件内容并写入本地""" import base64 @@ -526,21 +598,37 @@ def download_outputs( created_files = result.get("created_files", {}) log_content = result.get("log_content") - # 获取原始文件名的 stem(不含扩展名) - original_stem = Path(manifest.original_filename).stem if manifest.original_filename else "todo" + # 获取原始相对路径,用于重建目录结构 + original_rel_path = manifest.original_rel_path + if original_rel_path: + original_stem = Path(original_rel_path).stem + else: + original_stem = ( + Path(manifest.original_filename).stem + if manifest.original_filename + else "todo" + ) for filename, content_b64 in created_files.items(): content = base64.b64decode(content_b64) - # 将 todo.xxx 替换为原始文件名 - if filename.startswith("todo."): - suffix = Path(filename).suffix - new_filename = original_stem + suffix + + # 根据原始文件名生成新文件名 + suffix = Path(filename).suffix + new_filename = original_stem + suffix + + # 使用 original_rel_path 重建目录结构 + if original_rel_path and original_rel_path.parent != Path("."): + # 如果有子目录结构,创建相应的子目录 + local_path = ( + manifest.local_output_dir / original_rel_path.parent / new_filename + ) else: - new_filename = filename + # 没有子目录,直接放到输出目录 + local_path = manifest.local_output_dir / new_filename - local_path = manifest.local_output_dir / new_filename local_path.parent.mkdir(parents=True, exist_ok=True) local_path.write_bytes(content) + logging.info("[写入] %s (%d bytes)", local_path, len(content)) logging.info("写入文件: %s (%d bytes)", local_path, len(content)) # 写入 log 文件 @@ -552,18 +640,20 @@ def download_outputs( logging.info("写入日志: %s", log_path) -def summarize(manifest: UploadManifest, result: dict) -> None: +def summarize(manifest: UploadManifest, result: Dict) -> None: logging.info("=== 运行完成 ===") logging.info("Session: %s", manifest.session_id) logging.info("源路径: %s", manifest.local_source) logging.info( "输出路径: %s", - manifest.local_output_dir if manifest.source_type == "directory" else manifest.local_source.parent, + manifest.local_output_dir + if manifest.source_type == "directory" + else manifest.local_source.parent, ) created_files = result.get("created_files", {}) if created_files: logging.info("新生成文件:") - for filename in created_files: + for filename in created_files.keys(): logging.info(" %s", filename) @@ -580,8 +670,10 @@ def parse_args() -> argparse.Namespace: def prompt_exit(enabled: bool) -> None: if not enabled: return - with contextlib.suppress(EOFError): + try: input("输入任意键退出...") + except EOFError: + pass def main() -> int: @@ -597,8 +689,13 @@ def main() -> int: if selection.input_path.is_dir(): # 文件夹模式:逐个处理文件,容器复用 - logging.info("检测到文件夹输入,将逐个处理 %d 个音频文件", len(scan_result.audio_files)) - success_count, fail_count = process_directory_files(volume, selection, scan_result.audio_files) + logging.info( + "检测到文件夹输入,将逐个处理 %d 个音频文件", + len(scan_result.audio_files), + ) + success_count, fail_count = process_directory_files( + volume, selection, scan_result.audio_files + ) logging.info("=" * 60) logging.info("=== 批量处理完成 ===") logging.info("成功: %d, 失败: %d", success_count, fail_count) @@ -606,12 +703,29 @@ def main() -> int: logging.info("✅ 请在上方输出路径查看字幕结果。") else: # 单文件模式:保持原有逻辑 - manifest = upload_single_file(volume, selection, selection.input_path) - payload = build_job_payload(selection, manifest) - result = run_remote_pipeline(volume, selection, manifest, payload) - download_outputs(manifest, result) - summarize(manifest, result) - logging.info("✅ 请在上方输出路径查看字幕结果。") + input_path = selection.input_path + output_dir = input_path.parent + + # 检查是否已有输出文件 + input_stem = input_path.stem + has_existing = False + for suffix in SUB_SUFFIXES: + output_file = output_dir / (input_stem + suffix) + if output_file.exists(): + has_existing = True + logging.info( + "[跳过] 文件 %s 已存在字幕文件,跳过处理", input_path.name + ) + break + + if not has_existing: + manifest = upload_single_file(volume, selection, input_path) + payload = build_job_payload(selection, manifest) + result = run_remote_pipeline(volume, selection, manifest, payload) + download_outputs(manifest, result) + summarize(manifest, result) + + logging.info("✅ 请在文件所在目录查看字幕结果。") except KeyboardInterrupt: logging.warning("用户中断,未执行任何远程操作。") exit_code = 1 @@ -627,9 +741,10 @@ def main() -> int: return exit_code -def _remote_pipeline(job: dict) -> dict: - import os +def _remote_pipeline(job: Dict) -> Dict: + import subprocess from pathlib import Path + import os # 强制重新加载 Volume,确保看到最新上传的文件 from modal import Volume @@ -637,7 +752,9 @@ def _remote_pipeline(job: dict) -> dict: volume = Volume.from_name("Faster_Whisper") volume.reload() - def run(cmd: Sequence[str], cwd: str | None = None, env: dict | None = None) -> None: + def run( + cmd: Sequence[str], cwd: Optional[str] = None, env: Optional[dict] = None + ) -> None: print(" ".join(cmd), flush=True) subprocess.run(cmd, check=True, cwd=cwd, env=env) @@ -695,7 +812,10 @@ def snapshot(path: str) -> set: files.add(str(f)) return files - before = {target["remote_dir"]: snapshot(target["remote_dir"]) for target in job["output_targets"]} + before = { + target["remote_dir"]: snapshot(target["remote_dir"]) + for target in job["output_targets"] + } output_dir = Path(job["remote_output_dir"]) output_dir.mkdir(parents=True, exist_ok=True) @@ -742,7 +862,9 @@ def snapshot(path: str) -> set: files = list(session_dir.iterdir()) file_names = [f.name for f in files] log(f"等待文件出现: {input_path} ({waited}s)") - log(f" 当前 {session_dir} 下有 {len(files)} 个文件/文件夹: {file_names}") + log( + f" 当前 {session_dir} 下有 {len(files)} 个文件/文件夹: {file_names}" + ) else: log(f"等待文件出现: {input_path} ({waited}s)") log(f" 目录不存在: {session_dir}") @@ -759,7 +881,7 @@ def snapshot(path: str) -> set: # 打印调试信息 sessions_dir = mount_root / "sessions" log(f"推理命令执行失败,错误码: {e.returncode}") - log("=== 调试信息 ===") + log(f"=== 调试信息 ===") # 统计 sessions 目录下的文件夹数量 if sessions_dir.exists(): @@ -773,14 +895,14 @@ def snapshot(path: str) -> set: input_dir = Path(input_path).parent log(f"待处理文件目录: {input_dir}") if input_dir.exists(): - log("目录内容:") + log(f"目录内容:") for item in input_dir.iterdir(): item_type = "目录" if item.is_dir() else "文件" log(f" [{item_type}] {item.name}") else: log(f"目录不存在: {input_dir}") - log("=== 调试信息结束 ===") + log(f"=== 调试信息结束 ===") raise def to_volume_path(path_str: str) -> str: @@ -794,7 +916,9 @@ def to_volume_path(path_str: str) -> str: remote_dir = target["remote_dir"] after = snapshot(remote_dir) prev = before.get(remote_dir, set()) - new_files = sorted(file for file in after - prev if Path(file).suffix.lower() in SUB_SUFFIXES) + new_files = sorted( + file for file in after - prev if Path(file).suffix.lower() in SUB_SUFFIXES + ) for file_path in new_files: file_path = Path(file_path) if file_path.exists():