diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 79d5f90..9265163 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,7 +22,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: pip3 install -r requirements.txt -r requirements-dev.txt + run: pip3 install -r requirements.txt -r requirements-dev.txt -r requirements-metrics.txt - name: Lint run: make lint diff --git a/README.md b/README.md index 0961a5d..be4c0d9 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ To install from PyPI: ```bash pip install speechmatics-python ``` + +To use the sm-metrics tool for diarization features and speaker identification metrics, install with the optional dependencies: +```bash +pip install speechmatics-python[metrics] +``` To install from source: ```bash git clone https://github.com/speechmatics/speechmatics-python diff --git a/asr_metrics/cli.py b/asr_metrics/cli.py index a4ea25f..5d089d2 100644 --- a/asr_metrics/cli.py +++ b/asr_metrics/cli.py @@ -1,34 +1,74 @@ """Entrypoint for SM metrics""" import argparse +import sys -import asr_metrics.diarization.sm_diarization_metrics.cookbook as diarization_metrics -import asr_metrics.wer.__main__ as wer_metrics +try: + import asr_metrics.wer.__main__ as wer_metrics + + WER_AVAILABLE = True +except ImportError: + WER_AVAILABLE = False + +try: + import asr_metrics.diarization.sm_diarization_metrics.cookbook as diarization_metrics + + DIARIZATION_AVAILABLE = True +except ImportError: + DIARIZATION_AVAILABLE = False def main(): - parser = argparse.ArgumentParser(description="Your CLI description") + parser = argparse.ArgumentParser( + description="Speechmatics metrics tool for WER and diarization" + ) # Create subparsers subparsers = parser.add_subparsers( dest="mode", help="Metrics mode. Choose from 'wer' or 'diarization'" ) - subparsers.required = True # Make sure a subparser id always provided + subparsers.required = True # Make sure a subparser is always provided - wer_parser = subparsers.add_parser("wer", help="Entrypoint for WER metrics") - wer_metrics.get_wer_args(wer_parser) + if WER_AVAILABLE: + wer_parser = subparsers.add_parser("wer", help="Entrypoint for WER metrics") + wer_metrics.get_wer_args(wer_parser) + else: + wer_parser = subparsers.add_parser( + "wer", help="Entrypoint for WER metrics (requires additional dependencies)" + ) - diarization_parser = subparsers.add_parser( - "diarization", help="Entrypoint for diarization metrics" - ) - diarization_metrics.get_diarization_args(diarization_parser) + if DIARIZATION_AVAILABLE: + diarization_parser = subparsers.add_parser( + "diarization", help="Entrypoint for diarization metrics" + ) + diarization_metrics.get_diarization_args(diarization_parser) + else: + diarization_parser = subparsers.add_parser( + "diarization", + help="Entrypoint for diarization metrics (requires pyannote dependencies)", + ) + diarization_parser.add_argument( + "--help-install", + action="store_true", + help="Show instructions for installing diarization dependencies", + ) args = parser.parse_args() if args.mode == "wer": - wer_metrics.main(args) + if WER_AVAILABLE: + wer_metrics.main(args) + else: + print("Error: WER metrics require additional dependencies.") + print("Please install them with: pip install speechmatics-python[metrics]") + sys.exit(1) elif args.mode == "diarization": - diarization_metrics.main(args) + if DIARIZATION_AVAILABLE: + diarization_metrics.main(args) + else: + print("Error: Diarization metrics require additional dependencies.") + print("Please install them with: pip install speechmatics-python[metrics]") + sys.exit(1) else: print("Unsupported mode. Please use 'wer' or 'diarization'") diff --git a/requirements-metrics.txt b/requirements-metrics.txt new file mode 100644 index 0000000..e35140f --- /dev/null +++ b/requirements-metrics.txt @@ -0,0 +1,8 @@ +docopt +jiwer +more-itertools +pandas +pyannote.core +pyannote.database +regex +tabulate>=0.8.9 diff --git a/requirements.txt b/requirements.txt index 875ada5..f19dcaf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,3 @@ httpx[http2]~=0.23 polling2~=0.5 toml~=0.10.2 tenacity~=8.2.3 -jiwer -regex -more-itertools -pyannote.core -pyannote.database -docopt -tabulate>=0.8.9 diff --git a/setup.py b/setup.py index f25c5bf..e4f4ca8 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,9 @@ def get_version(fname): long_description_content_type="text/markdown", install_requires=read_list("requirements.txt"), tests_require=read_list("requirements-dev.txt"), + extras_require={ + "metrics": read_list("requirements-metrics.txt"), + }, entry_points={ "console_scripts": [ "speechmatics = speechmatics.cli:main",