diff --git a/contrib/download_depccg_model.py b/contrib/download_depccg_model.py index dd3004b0..e84a2110 100644 --- a/contrib/download_depccg_model.py +++ b/contrib/download_depccg_model.py @@ -2,7 +2,7 @@ from urllib.request import urlretrieve from depccg.instance_models import MODEL_DIRECTORY -URL = 'https://qnlp.cambridgequantum.com/models/tri_headfirst.tar.gz' +URL = 'https://qnlp.quantinuum.com/models/tri_headfirst.tar.gz' print('Please consider using Bobcat, the parser included with lambeq,\n' 'instead of depccg.') diff --git a/lambeq/backend/snake_removal.py b/lambeq/backend/snake_removal.py index 346802ea..d83905f3 100644 --- a/lambeq/backend/snake_removal.py +++ b/lambeq/backend/snake_removal.py @@ -30,7 +30,11 @@ class InterchangerError(Exception): """ This is raised when we try to interchange conected boxes. """ def __init__(self, box0: Box, box1: Box) -> None: - super().__init__(f'Boxes {box0} and {box1} do not commute.') + super().__init__(box0, box1) + + def __str__(self) -> str: + return f'Boxes {self.args[0]} and {self.args[1]} do not commute.' + def snake_removal(diagram: Diagram, left: bool = False) -> Iterator[Diagram]: diff --git a/lambeq/text2diagram/ccg_rule.py b/lambeq/text2diagram/ccg_rule.py index bb62eddb..55261b98 100644 --- a/lambeq/text2diagram/ccg_rule.py +++ b/lambeq/text2diagram/ccg_rule.py @@ -27,6 +27,7 @@ class CCGRuleUseError(Exception): """Error raised when a :py:class:`CCGRule` is applied incorrectly.""" def __init__(self, rule: CCGRule, message: str) -> None: + super().__init__(rule, message) self.rule = rule self.message = message diff --git a/lambeq/text2diagram/ccg_type.py b/lambeq/text2diagram/ccg_type.py index 1f6b35ef..bf894472 100644 --- a/lambeq/text2diagram/ccg_type.py +++ b/lambeq/text2diagram/ccg_type.py @@ -28,6 +28,7 @@ class CCGParseError(Exception): """Error when parsing a CCG type string.""" def __init__(self, cat: str, message: str) -> None: + super().__init__(cat, message) self.cat = cat self.message = message diff --git a/lambeq/text2diagram/ccgbank_parser.py b/lambeq/text2diagram/ccgbank_parser.py index e51f3623..f07336af 100644 --- a/lambeq/text2diagram/ccgbank_parser.py +++ b/lambeq/text2diagram/ccgbank_parser.py @@ -58,6 +58,7 @@ def __init__(self, sentence: str = '', message: str = '') -> None: else: self.sentence = '' self.message = sentence + super().__init__(sentence, message) def __str__(self) -> str: if self.sentence: diff --git a/lambeq/text2diagram/depccg_parser.py b/lambeq/text2diagram/depccg_parser.py index 3d844a7b..8d120c96 100644 --- a/lambeq/text2diagram/depccg_parser.py +++ b/lambeq/text2diagram/depccg_parser.py @@ -66,9 +66,10 @@ def _import_depccg() -> None: class DepCCGParseError(Exception): def __init__(self, sentence: str) -> None: + super().__init__(sentence) self.sentence = sentence - def __str__(self) -> str: # pragma: no cover + def __str__(self) -> str: return f'depccg failed to parse: "{self.sentence!r}".' diff --git a/lambeq/text2diagram/model_based_reader/bobcat_parser.py b/lambeq/text2diagram/model_based_reader/bobcat_parser.py index d4b5f43f..702b7dbd 100644 --- a/lambeq/text2diagram/model_based_reader/bobcat_parser.py +++ b/lambeq/text2diagram/model_based_reader/bobcat_parser.py @@ -48,6 +48,7 @@ class BobcatParseError(Exception): def __init__(self, sentence: str) -> None: + super().__init__(sentence) self.sentence = sentence def __str__(self) -> str: diff --git a/lambeq/text2diagram/model_based_reader/model_downloader.py b/lambeq/text2diagram/model_based_reader/model_downloader.py index 44963798..8d1a24b3 100644 --- a/lambeq/text2diagram/model_based_reader/model_downloader.py +++ b/lambeq/text2diagram/model_based_reader/model_downloader.py @@ -31,7 +31,7 @@ from lambeq.core.globals import VerbosityLevel from lambeq.typing import StrPathT -MODELS_URL = 'https://qnlp.cambridgequantum.com/models' +MODELS_URL = 'https://qnlp.quantinuum.com/models' MODELS = {'bobcat', 'oncilla'} VERSION_FNAME = 'version.txt' CHECKSUM_FNAME = 'model_checksum.sha256' @@ -42,6 +42,7 @@ class ModelDownloaderError(Exception): def __init__(self, error_msg: str) -> None: + super().__init__(error_msg) self.error_msg = error_msg def __str__(self) -> str: @@ -87,7 +88,10 @@ def __init__(self, def get_url(self) -> str: """Get URL for the latest version of specified model.""" - return f'{MODELS_URL}/{self.model}/latest' + model_path = self.model + if self.model == 'bobcat': + model_path = 'bert' + return f'{MODELS_URL}/{model_path}/latest' def get_dir(self, cache_dir: StrPathT | None = None) -> Path: @@ -125,6 +129,13 @@ def get_local_model_version(self) -> str | None: try: with open(self.model_dir / VERSION_FNAME) as f: local_version = f.read().strip() + except FileNotFoundError: + # Fallback: Check for nested directory (e.g. bobcat/bobcat) + try: + with open(self.model_dir / self.model / VERSION_FNAME) as f: + local_version = f.read().strip() + except Exception: + local_version = None except Exception: local_version = None @@ -137,6 +148,14 @@ def download_model(self, and then extract the model to `model_dir`""" if self.remote_version is None: + if self.get_local_model_version() is not None: + print('Failed to retrieve remote model version. ' + 'Using local model.', file=sys.stderr) + # Check if model is in nested directory and update path + nested_dir = self.model_dir / self.model + if nested_dir.exists() and (nested_dir / 'config.json').exists(): + self.model_dir = nested_dir + return raise self.version_retrieval_error expected_checksum = self._get_remote_checksum() diff --git a/lambeq/text2diagram/model_based_reader/oncilla_parser.py b/lambeq/text2diagram/model_based_reader/oncilla_parser.py index f15590f3..a516052d 100644 --- a/lambeq/text2diagram/model_based_reader/oncilla_parser.py +++ b/lambeq/text2diagram/model_based_reader/oncilla_parser.py @@ -45,6 +45,7 @@ class OncillaParseError(Exception): def __init__(self, sentence: str, reason: str = '') -> None: + super().__init__(sentence, reason) self.sentence = sentence self.reason = reason diff --git a/lambeq/text2diagram/pregroup_tree.py b/lambeq/text2diagram/pregroup_tree.py index 7ed6a9c9..b8780ab2 100644 --- a/lambeq/text2diagram/pregroup_tree.py +++ b/lambeq/text2diagram/pregroup_tree.py @@ -24,6 +24,7 @@ class PregroupTreeNodeError(Exception): def __init__(self, sentence: str) -> None: + super().__init__(sentence) self.sentence = sentence def __str__(self) -> str: diff --git a/lambeq/text2diagram/web_parser.py b/lambeq/text2diagram/web_parser.py index ee152464..e2de56c1 100644 --- a/lambeq/text2diagram/web_parser.py +++ b/lambeq/text2diagram/web_parser.py @@ -40,10 +40,11 @@ class WebParseError(OSError): def __init__(self, sentence: str) -> None: + super().__init__(sentence) self.sentence = sentence def __str__(self) -> str: - return (f'Web parser could not parse {repr(self.sentence)}') + return f'Web parser could not parse {repr(self.sentence)}' class WebParser(CCGParser): diff --git a/lambeq/training/nelder_mead_optimizer.py b/lambeq/training/nelder_mead_optimizer.py index 6b2377c0..c2f78eb6 100644 --- a/lambeq/training/nelder_mead_optimizer.py +++ b/lambeq/training/nelder_mead_optimizer.py @@ -259,7 +259,7 @@ def objective(self, x: Iterable[Any], y: ArrayLike, w: ArrayLike) -> float: raise ValueError( 'Objective function must return a scalar' ) from e - return result # type: ignore[return-value] + return result def backward(self, batch: tuple[Iterable[Any], np.ndarray]) -> float: """Calculate the gradients of the loss function. diff --git a/setup.cfg b/setup.cfg index edf6f957..c6a8c228 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,6 +75,7 @@ extras = discopy >= 1.1.7 jax < 0.6.0 jaxlib < 0.6.0 + autoray < 0.6.1 pennylane >= 0.29.1, < 0.37.0 pennylane-honeywell pennylane-qiskit >= 0.29.1, < 0.37.0