Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion contrib/download_depccg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from urllib.request import urlretrieve
from depccg.instance_models import MODEL_DIRECTORY

URL = 'https://qnlp.cambridgequantum.com/models/tri_headfirst.tar.gz'
URL = 'https://qnlp.quantinuum.com/models/tri_headfirst.tar.gz'

print('Please consider using Bobcat, the parser included with lambeq,\n'
'instead of depccg.')
Expand Down
6 changes: 5 additions & 1 deletion lambeq/backend/snake_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
class InterchangerError(Exception):
""" This is raised when we try to interchange conected boxes. """
def __init__(self, box0: Box, box1: Box) -> None:
super().__init__(f'Boxes {box0} and {box1} do not commute.')
super().__init__(box0, box1)

def __str__(self) -> str:
return f'Boxes {self.args[0]} and {self.args[1]} do not commute.'



def snake_removal(diagram: Diagram, left: bool = False) -> Iterator[Diagram]:
Expand Down
1 change: 1 addition & 0 deletions lambeq/text2diagram/ccg_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
class CCGRuleUseError(Exception):
"""Error raised when a :py:class:`CCGRule` is applied incorrectly."""
def __init__(self, rule: CCGRule, message: str) -> None:
super().__init__(rule, message)
self.rule = rule
self.message = message

Expand Down
1 change: 1 addition & 0 deletions lambeq/text2diagram/ccg_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class CCGParseError(Exception):
"""Error when parsing a CCG type string."""

def __init__(self, cat: str, message: str) -> None:
super().__init__(cat, message)
self.cat = cat
self.message = message

Expand Down
1 change: 1 addition & 0 deletions lambeq/text2diagram/ccgbank_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, sentence: str = '', message: str = '') -> None:
else:
self.sentence = ''
self.message = sentence
super().__init__(sentence, message)

def __str__(self) -> str:
if self.sentence:
Expand Down
3 changes: 2 additions & 1 deletion lambeq/text2diagram/depccg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ def _import_depccg() -> None:

class DepCCGParseError(Exception):
def __init__(self, sentence: str) -> None:
super().__init__(sentence)
self.sentence = sentence

def __str__(self) -> str: # pragma: no cover
def __str__(self) -> str:
return f'depccg failed to parse: "{self.sentence!r}".'


Expand Down
1 change: 1 addition & 0 deletions lambeq/text2diagram/model_based_reader/bobcat_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

class BobcatParseError(Exception):
def __init__(self, sentence: str) -> None:
super().__init__(sentence)
self.sentence = sentence

def __str__(self) -> str:
Expand Down
23 changes: 21 additions & 2 deletions lambeq/text2diagram/model_based_reader/model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from lambeq.core.globals import VerbosityLevel
from lambeq.typing import StrPathT

MODELS_URL = 'https://qnlp.cambridgequantum.com/models'
MODELS_URL = 'https://qnlp.quantinuum.com/models'
MODELS = {'bobcat', 'oncilla'}
VERSION_FNAME = 'version.txt'
CHECKSUM_FNAME = 'model_checksum.sha256'
Expand All @@ -42,6 +42,7 @@

class ModelDownloaderError(Exception):
def __init__(self, error_msg: str) -> None:
super().__init__(error_msg)
self.error_msg = error_msg

def __str__(self) -> str:
Expand Down Expand Up @@ -87,7 +88,10 @@ def __init__(self,
def get_url(self) -> str:
"""Get URL for the latest version of specified model."""

return f'{MODELS_URL}/{self.model}/latest'
model_path = self.model
if self.model == 'bobcat':
model_path = 'bert'
return f'{MODELS_URL}/{model_path}/latest'

def get_dir(self,
cache_dir: StrPathT | None = None) -> Path:
Expand Down Expand Up @@ -125,6 +129,13 @@ def get_local_model_version(self) -> str | None:
try:
with open(self.model_dir / VERSION_FNAME) as f:
local_version = f.read().strip()
except FileNotFoundError:
# Fallback: Check for nested directory (e.g. bobcat/bobcat)
try:
with open(self.model_dir / self.model / VERSION_FNAME) as f:
local_version = f.read().strip()
except Exception:
local_version = None
except Exception:
local_version = None

Expand All @@ -137,6 +148,14 @@ def download_model(self,
and then extract the model to `model_dir`"""

if self.remote_version is None:
if self.get_local_model_version() is not None:
print('Failed to retrieve remote model version. '
'Using local model.', file=sys.stderr)
# Check if model is in nested directory and update path
nested_dir = self.model_dir / self.model
if nested_dir.exists() and (nested_dir / 'config.json').exists():
self.model_dir = nested_dir
return
raise self.version_retrieval_error

expected_checksum = self._get_remote_checksum()
Expand Down
1 change: 1 addition & 0 deletions lambeq/text2diagram/model_based_reader/oncilla_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

class OncillaParseError(Exception):
def __init__(self, sentence: str, reason: str = '') -> None:
super().__init__(sentence, reason)
self.sentence = sentence
self.reason = reason

Expand Down
1 change: 1 addition & 0 deletions lambeq/text2diagram/pregroup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

class PregroupTreeNodeError(Exception):
def __init__(self, sentence: str) -> None:
super().__init__(sentence)
self.sentence = sentence

def __str__(self) -> str:
Expand Down
3 changes: 2 additions & 1 deletion lambeq/text2diagram/web_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@

class WebParseError(OSError):
def __init__(self, sentence: str) -> None:
super().__init__(sentence)
self.sentence = sentence

def __str__(self) -> str:
return (f'Web parser could not parse {repr(self.sentence)}')
return f'Web parser could not parse {repr(self.sentence)}'


class WebParser(CCGParser):
Expand Down
2 changes: 1 addition & 1 deletion lambeq/training/nelder_mead_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def objective(self, x: Iterable[Any], y: ArrayLike, w: ArrayLike) -> float:
raise ValueError(
'Objective function must return a scalar'
) from e
return result # type: ignore[return-value]
return result

def backward(self, batch: tuple[Iterable[Any], np.ndarray]) -> float:
"""Calculate the gradients of the loss function.
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ extras =
discopy >= 1.1.7
jax < 0.6.0
jaxlib < 0.6.0
autoray < 0.6.1
pennylane >= 0.29.1, < 0.37.0
pennylane-honeywell
pennylane-qiskit >= 0.29.1, < 0.37.0
Expand Down