Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion lambeq/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@

from __future__ import annotations

import logging
from math import floor
import pickle
from typing import Any, List, Union
from typing import Any, List, TYPE_CHECKING, Union

import spacy


if TYPE_CHECKING:
import spacy.cli


TokenisedSentenceType = List[str]
Expand Down Expand Up @@ -94,3 +101,16 @@ def normalise_duration(duration_secs: float | None) -> str:
def fast_deepcopy(obj: Any) -> Any:
    """Return a deep copy of *obj* via a pickle round-trip.

    This is usually faster than `copy.deepcopy`, but it only works
    for objects that are picklable.
    """
    serialised = pickle.dumps(obj)
    return pickle.loads(serialised)


def get_spacy_tokeniser(
    model: str = 'en_core_web_sm'
) -> spacy.language.Language:
    """Load a SpaCy language model, downloading it first if necessary.

    Parameters
    ----------
    model : str, default: 'en_core_web_sm'
        Name of the SpaCy model to load.

    Returns
    -------
    spacy.language.Language
        The loaded SpaCy pipeline.

    """
    try:
        return spacy.load(model)
    except OSError:
        # `spacy.cli` is imported at module level only under
        # TYPE_CHECKING, so bind the submodule explicitly here rather
        # than relying on spaCy having imported it as a side effect.
        import spacy.cli

        logger = logging.getLogger(__name__)
        logger.warning('Downloading SpaCy tokeniser. '
                       'This action only has to happen once.')
        spacy.cli.download(model)
        return spacy.load(model)
15 changes: 11 additions & 4 deletions lambeq/experimental/discocirc/coref_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@

from abc import ABC, abstractmethod
import re
from typing import TYPE_CHECKING

import spacy
import torch

from lambeq.core.utils import get_spacy_tokeniser


if TYPE_CHECKING:
import spacy.cli


SPACY_NOUN_POS = {'NOUN', 'PROPN', 'PRON'}
TokenisedTextT = list[list[str]]
Expand Down Expand Up @@ -103,7 +110,7 @@ def __init__(
from maverick import Maverick

# Create basic tokenisation pipeline, for POS
self.nlp = spacy.load('en_core_web_sm')
self.nlp = get_spacy_tokeniser()
self.model = Maverick(hf_name_or_path=hf_name_or_path,
device=device)

Expand Down Expand Up @@ -165,7 +172,7 @@ class SpacyCoreferenceResolver(CoreferenceResolver):

def __init__(self):
# Create basic tokenisation pipeline, for POS
self.nlp = spacy.load('en_core_web_sm')
self.nlp = get_spacy_tokeniser()

# Add coreference resolver pipe stage
try:
Expand All @@ -174,9 +181,9 @@ def __init__(self):
except OSError as ose:
raise UserWarning(
'`SpacyCoreferenceResolver` requires the experimental'
' `en_coreferenc_web_trf` model.'
' `en_coreference_web_trf` model.'
' See https://github.com/explosion/spacy-experimental/releases/tag/v0.6.1' # noqa: W505, E501
' for installation instructions. For a stable installation, '
' for installation instructions. For a stable installation,'
' please use Python 3.10.'
) from ose

Expand Down
24 changes: 6 additions & 18 deletions lambeq/tokeniser/spacy_tokeniser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,24 @@
__all__ = ['SpacyTokeniser']

from collections.abc import Iterable
import logging
from typing import TYPE_CHECKING

from lambeq.tokeniser import Tokeniser
import spacy
import spacy.lang.en

if TYPE_CHECKING:
import spacy
import spacy.cli
from lambeq.core.utils import get_spacy_tokeniser
from lambeq.tokeniser import Tokeniser


def _import_spacy() -> None:
global spacy
import spacy
import spacy.lang.en
if TYPE_CHECKING:
import spacy.cli


class SpacyTokeniser(Tokeniser):
"""Tokeniser class based on SpaCy."""

def __init__(self) -> None:
_import_spacy()
try:
self.tokeniser = spacy.load('en_core_web_sm')
except OSError:
logger = logging.getLogger(__name__)
logger.warning('Downloading SpaCy tokeniser. '
'This action only has to happen once.')
spacy.cli.download('en_core_web_sm')
self.tokeniser = spacy.load('en_core_web_sm')
self.tokeniser = get_spacy_tokeniser()
self.spacy_nlp = spacy.lang.en.English()
self.spacy_nlp.add_pipe('sentencizer')

Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ test =
pytest

experimental =
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl
maverick-coref == 1.0.3
maverick-coref >= 1.0.3

[options.entry_points]
console_scripts =
Expand Down
Loading