diff --git a/lambeq/core/utils.py b/lambeq/core/utils.py index 7163b6c4..0e32d405 100644 --- a/lambeq/core/utils.py +++ b/lambeq/core/utils.py @@ -14,9 +14,16 @@ from __future__ import annotations +import logging from math import floor import pickle -from typing import Any, List, Union +from typing import Any, List, TYPE_CHECKING, Union + +import spacy + + +if TYPE_CHECKING: + import spacy.cli TokenisedSentenceType = List[str] @@ -94,3 +101,16 @@ def normalise_duration(duration_secs: float | None) -> str: def fast_deepcopy(obj: Any) -> Any: """Fast deepcopy (faster than `copy.deepcopy`).""" return pickle.loads(pickle.dumps(obj)) + + +def get_spacy_tokeniser( + model: str = 'en_core_web_sm' +) -> spacy.language.Language: + try: + return spacy.load(model) + except OSError: + logger = logging.getLogger(__name__) + logger.warning('Downloading SpaCy tokeniser. ' + 'This action only has to happen once.') + spacy.cli.download(model) + return spacy.load(model) diff --git a/lambeq/experimental/discocirc/coref_resolver.py b/lambeq/experimental/discocirc/coref_resolver.py index 398e84aa..7d4b50c5 100644 --- a/lambeq/experimental/discocirc/coref_resolver.py +++ b/lambeq/experimental/discocirc/coref_resolver.py @@ -14,10 +14,17 @@ from abc import ABC, abstractmethod import re +from typing import TYPE_CHECKING import spacy import torch +from lambeq.core.utils import get_spacy_tokeniser + + +if TYPE_CHECKING: + import spacy.cli + SPACY_NOUN_POS = {'NOUN', 'PROPN', 'PRON'} TokenisedTextT = list[list[str]] @@ -103,7 +110,7 @@ def __init__( from maverick import Maverick # Create basic tokenisation pipeline, for POS - self.nlp = spacy.load('en_core_web_sm') + self.nlp = get_spacy_tokeniser() self.model = Maverick(hf_name_or_path=hf_name_or_path, device=device) @@ -165,7 +172,7 @@ class SpacyCoreferenceResolver(CoreferenceResolver): def __init__(self): # Create basic tokenisation pipeline, for POS - self.nlp = spacy.load('en_core_web_sm') + self.nlp = get_spacy_tokeniser() # Add coreference resolver pipe stage try: @@ -174,9 +181,9 @@ def __init__(self): except OSError as ose: raise UserWarning( '`SpacyCoreferenceResolver` requires the experimental' - ' `en_coreferenc_web_trf` model.' + ' `en_coreference_web_trf` model.' ' See https://github.com/explosion/spacy-experimental/releases/tag/v0.6.1' # noqa: W505, E501 - ' for installation instructions. For a stable installation, ' + ' for installation instructions. For a stable installation,' ' please use Python 3.10.' ) from ose diff --git a/lambeq/tokeniser/spacy_tokeniser.py b/lambeq/tokeniser/spacy_tokeniser.py index 68999e35..867e6f7f 100644 --- a/lambeq/tokeniser/spacy_tokeniser.py +++ b/lambeq/tokeniser/spacy_tokeniser.py @@ -24,20 +24,16 @@ __all__ = ['SpacyTokeniser'] from collections.abc import Iterable -import logging from typing import TYPE_CHECKING -from lambeq.tokeniser import Tokeniser +import spacy +import spacy.lang.en -if TYPE_CHECKING: - import spacy - import spacy.cli +from lambeq.core.utils import get_spacy_tokeniser +from lambeq.tokeniser import Tokeniser -def _import_spacy() -> None: - global spacy - import spacy - import spacy.lang.en +if TYPE_CHECKING: import spacy.cli @@ -45,15 +41,7 @@ class SpacyTokeniser(Tokeniser): """Tokeniser class based on SpaCy.""" def __init__(self) -> None: - _import_spacy() - try: - self.tokeniser = spacy.load('en_core_web_sm') - except OSError: - logger = logging.getLogger(__name__) - logger.warning('Downloading SpaCy tokeniser. ' - 'This action only has to happen once.') - spacy.cli.download('en_core_web_sm') - self.tokeniser = spacy.load('en_core_web_sm') + self.tokeniser = get_spacy_tokeniser() self.spacy_nlp = spacy.lang.en.English() self.spacy_nlp.add_pipe('sentencizer') diff --git a/setup.cfg b/setup.cfg index 2fca7dd0..edf6f957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -91,8 +91,7 @@ test = pytest experimental = - en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl - maverick-coref == 1.0.3 + maverick-coref >= 1.0.3 [options.entry_points] console_scripts =