Skip to content

Commit 7c12cae

Browse files
bug(medcat): CU-869ckx6dr Allow for better supervised training (#374)
* CU-869ckx6dr: Add extra test to trainer to make sure it tests on multiple projects * CU-869ckx6dr: Add new method for reuse of entities when getting based on tokens * CU-869ckx6dr: Add simple test for entity persitance in document * CU-869ckx6dr: Small addition to test * CU-869ckx6dr: Prepare document with appropriate entities at training time * CU-869ckx6dr: Update tests to work with new setup * CU-869ckx6dr: Add a new test for entities in add_and_train_concept. * CU-869ckx6dr: Add deprecation arning to old / unused entity_from_tokens method in pipe * CU-869ckx6dr: Add deprecation warning to old / unused entity_from_tokens method in tokenizers * CU-869ckx6dr: Deprecate unused method on a protocol level as well * CU-869ckx6dr: Fix linting issue * CU-869ckx6dr: Fix minor issues in test-time supervised triaining data * CU-869ckx6dr: Add enw test for order of training examples * CU-869ckx6dr: Minor changes to trainer tests * CU-869ckx6dr: Allow a little longer for the relcat tutorial to run --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b43c22e commit 7c12cae

10 files changed

Lines changed: 235 additions & 12 deletions

File tree

.github/workflows/medcat-v2-tutorials_main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,4 @@ jobs:
8383
- name: Smoke test tutorial
8484
run: |
8585
pytest --capture=no --collect-only --nbmake ${{ matrix.part }}
86-
pytest --capture=no --nbmake -n=auto --nbmake-kernel=smoketests --nbmake-timeout=1800 ${{ matrix.part }}
86+
pytest --capture=no --nbmake -n=auto --nbmake-kernel=smoketests --nbmake-timeout=2400 ${{ matrix.part }}

medcat-v2/medcat/pipeline/pipeline.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Iterable, Union
22
import logging
33
import os
4+
import warnings
45

56
from medcat.utils.defaults import COMPONENTS_FOLDER
67
from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer
@@ -43,8 +44,19 @@ def create_entity(self, doc: MutableDocument,
4344
doc, token_start_index, token_end_index, label)
4445

4546
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
47+
warnings.warn(
48+
"The `medcat.pipeline.pipeline.entity_from_tokens` method is"
49+
"depreacated and subject to removal in a future release. Please "
50+
"use `medcat.pipeline.pipeline.entity_from_tokens_in_doc` instead.",
51+
DeprecationWarning,
52+
stacklevel=2
53+
)
4654
return self.tokenizer.entity_from_tokens(tokens)
4755

56+
def entity_from_tokens_in_doc(
57+
self, tokens: list[MutableToken], doc: MutableDocument) -> MutableEntity:
58+
return self.tokenizer.entity_from_tokens_in_doc(tokens, doc)
59+
4860
def __call__(self, text: str) -> MutableDocument:
4961
doc = self.tokenizer(text)
5062
for comp in self.components:
@@ -342,6 +354,23 @@ def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
342354
"""
343355
return self._tokenizer.entity_from_tokens(tokens)
344356

357+
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],
358+
doc: MutableDocument) -> MutableEntity:
359+
"""Get the entity from the list of tokens in a document.
360+
361+
This effectively turns a list of (consecutive) documents
362+
into an entity. But it is also designed to reuse existing
363+
instances on the document instead of creating new ones.
364+
365+
Args:
366+
tokens (list[MutableToken]): The tokens to use.
367+
doc (MutableDocument): The document for these tokens.
368+
369+
Returns:
370+
MutableEntity: The resulting entity.
371+
"""
372+
return self._tokenizer.entity_from_tokens_in_doc(tokens, doc)
373+
345374
def get_component(self, ctype: CoreComponentType) -> CoreComponent:
346375
"""Get the core component by the component type.
347376

medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
from typing import cast, Optional, Iterator, overload, Union, Any, Type
33
from collections import defaultdict
4+
import warnings
45

56
from medcat.tokenizing.tokens import (
67
BaseToken, BaseEntity, BaseDocument,
@@ -340,13 +341,38 @@ def create_entity(self, doc: MutableDocument,
340341
# return Entity(span)
341342

342343
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
344+
warnings.warn(
345+
"The `medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens` method is"
346+
"depreacated and subject to removal in a future release. Please use "
347+
"`medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens_in_doc` "
348+
"instead.",
349+
DeprecationWarning,
350+
stacklevel=2
351+
)
343352
if not tokens:
344353
raise ValueError("Need at least one token for an entity")
345354
doc = cast(Token, tokens[0])._doc
346355
start_index = doc._tokens.index(tokens[0])
347356
end_index = doc._tokens.index(tokens[-1])
348357
return _entity_from_tokens(doc, tokens, start_index, end_index)
349358

359+
def _get_existing_entity(self, tokens: list[MutableToken],
360+
doc: MutableDocument) -> Optional[MutableEntity]:
361+
if not tokens:
362+
return None
363+
for ent in doc.ner_ents + doc.linked_ents:
364+
if (ent.base.start_index == tokens[0].base.index and
365+
ent.base.end_index == tokens[-1].base.index):
366+
return ent
367+
return None
368+
369+
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],
370+
doc: MutableDocument) -> MutableEntity:
371+
existing_ent = self._get_existing_entity(tokens, doc)
372+
if existing_ent:
373+
return existing_ent
374+
return self.entity_from_tokens(tokens)
375+
350376
def _get_tokens_matches(self, text: str) -> list[re.Match[str]]:
351377
tokens = self.REGEX.finditer(text)
352378
return list(tokens)

medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import shutil
55
import logging
6+
import warnings
67

78
import spacy
89
from spacy.tokens import Span
@@ -77,13 +78,38 @@ def create_entity(self, doc: MutableDocument,
7778
return Entity(span)
7879

7980
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
81+
warnings.warn(
82+
"The `medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens` method is"
83+
"depreacated and subject to removal in a future release. Please use "
84+
"`medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens_in_doc` "
85+
"instead.",
86+
DeprecationWarning,
87+
stacklevel=2
88+
)
8089
if not tokens:
8190
raise ValueError("Need at least one token for an entity")
8291
spacy_tokens = cast(list[Token], tokens)
8392
span = Span(spacy_tokens[0]._delegate.doc, spacy_tokens[0].index,
8493
spacy_tokens[-1].index + 1)
8594
return Entity(span)
8695

96+
def _get_existing_entity(self, tokens: list[MutableToken],
97+
doc: MutableDocument) -> Optional[MutableEntity]:
98+
if not tokens:
99+
return None
100+
for ent in doc.ner_ents + doc.linked_ents:
101+
if (ent.base.start_index == tokens[0].base.index and
102+
ent.base.end_index == tokens[-1].base.index):
103+
return ent
104+
return None
105+
106+
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],
107+
doc: MutableDocument) -> MutableEntity:
108+
existing_ent = self._get_existing_entity(tokens, doc)
109+
if existing_ent:
110+
return existing_ent
111+
return self.entity_from_tokens(tokens)
112+
87113
def __call__(self, text: str) -> MutableDocument:
88114
if self._avoid_pipe:
89115
doc = Document(self._nlp.make_doc(text))

medcat-v2/medcat/tokenizing/tokenizers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,22 @@ def create_entity(self, doc: MutableDocument,
3434
pass
3535

3636
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
37-
"""Get an entity from the list of tokens.
37+
"""Deprecated: use entity_from_tokens_in_doc instead."""
38+
pass
39+
40+
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],
41+
doc: MutableDocument) -> MutableEntity:
42+
"""Get an entity from the list of tokens in the specified document.
43+
44+
This method is designed to reuse entities where possible.
3845
3946
Args:
4047
tokens (list[MutableToken]): List of tokens.
48+
doc (MutableDocument): The document for these tokens.
4149
4250
Returns:
4351
MutableEntity: The resulting entity.
4452
"""
45-
pass
4653

4754
def __call__(self, text: str) -> MutableDocument:
4855
pass

medcat-v2/medcat/trainer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from medcat.utils.data_utils import make_mc_train_test, get_false_positives
1212
from medcat.utils.filters import project_filters
1313
from medcat.data.mctexport import (
14-
MedCATTrainerExport, MedCATTrainerExportProject,
14+
MedCATTrainerExport, MedCATTrainerExportAnnotation, MedCATTrainerExportProject,
1515
MedCATTrainerExportDocument, count_all_annotations, iter_anns)
1616
from medcat.preprocessors.cleaners import prepare_name, NameDescriptor
1717
from medcat.components.types import CoreComponentType, TrainableComponent
@@ -397,6 +397,20 @@ def _train_supervised_for_project(self,
397397
docs, current_document, train_from_false_positives,
398398
devalue_others)
399399

400+
def _prepare_doc_with_anns(
401+
self, doc: MutableDocument,
402+
anns: list[MedCATTrainerExportAnnotation]) -> None:
403+
ents = []
404+
for ann in anns:
405+
tkns = doc.get_tokens(ann['start'], ann['end'])
406+
ents.append(self._pipeline.entity_from_tokens_in_doc(tkns, doc))
407+
# set NER ents
408+
doc.ner_ents.clear()
409+
doc.ner_ents.extend(ents)
410+
# duplicate for linked as well, but in a a separate list
411+
doc.linked_ents.clear()
412+
doc.linked_ents.extend(ents)
413+
400414
def _train_supervised_for_project2(self,
401415
docs: list[MedCATTrainerExportDocument],
402416
current_document: int,
@@ -412,17 +426,17 @@ def _train_supervised_for_project2(self,
412426
with temp_changed_config(self.config.components.linking,
413427
'train', False):
414428
mut_doc = self.caller(doc['text'])
429+
self._prepare_doc_with_anns(mut_doc, doc['annotations'])
415430

416431
# Compatibility with old output where annotations are a list
417-
for ann in doc['annotations']:
432+
for ann, mut_entity in zip(doc['annotations'], mut_doc.linked_ents):
418433
if ann.get('killed', False):
419434
continue
420435
logger.info(" Annotation %s (%s) [%d:%d]",
421436
ann['value'], ann['cui'], ann['start'], ann['end'])
422437
cui = ann['cui']
423438
start = ann['start']
424439
end = ann['end']
425-
mut_entity = mut_doc.get_tokens(start, end)
426440
if not mut_entity:
427441
logger.warning(
428442
"When looking for CUI '%s' (value '%s') [%d...%d] "

medcat-v2/tests/resources/supervised_mct_export.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
{
5959
"cui": "C04",
6060
"start": 81,
61-
"end": 87,
61+
"end": 88,
6262
"value": "fittest"
6363
}
6464
],
@@ -125,7 +125,7 @@
125125
"id": "ID-3",
126126
"last_modified": "2024-08-21",
127127
"name": "Doc#4",
128-
"text": "The RHS male is healthy as considered by all available tests. There are no indications that the patient is not fittest."
128+
"text": "The RHS male is healthy as considered by all available tests. There are no indications that the patient is not fittest"
129129
}
130130
],
131131
"id": "Project#0",

medcat-v2/tests/test_cat.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from contextlib import contextmanager
88

99
from medcat import cat
10+
from medcat.data.mctexport import count_all_annotations, iter_anns
1011
from medcat.data.model_card import ModelCard
1112
from medcat.vocab import Vocab
1213
from medcat.config import Config
@@ -576,7 +577,7 @@ class CATSupTrainingTests(CATUnsupTrainingTests):
576577
os.path.dirname(__file__), 'resources', 'supervised_mct_export.json'
577578
)
578579
# NOTE: should remain consistent unless we change the model or data
579-
EXPECTED_HASH = "7bfe01e8e36eb07d"
580+
EXPECTED_HASH = "9c299628c9e6c220"
580581

581582
@classmethod
582583
def _get_cui_counts(cls) -> dict[str, int]:
@@ -620,6 +621,21 @@ def test_clearing_training_works(self):
620621
self.assertEqual(self.cat.config.meta.unsup_trained, [])
621622
self.assertEqual(self.cat.config.meta.sup_trained, [])
622623

624+
def test_training_happens_in_correct_order(self):
625+
with captured_state_cdb(self.cat.cdb):
626+
with unittest.mock.patch.object(
627+
self.cat.trainer, "add_and_train_concept") as mock_add_and_train_concept:
628+
self._perform_training()
629+
mct_export = self._get_data()
630+
called_ents = [
631+
args.kwargs['mut_entity'] for args in mock_add_and_train_concept.call_args_list
632+
]
633+
self.assertEqual(len(called_ents), count_all_annotations(mct_export))
634+
for (_, _, ann), ent in zip(iter_anns(mct_export), called_ents):
635+
with self.subTest(f"Ann: {ann} vs Ent: {ent}"):
636+
self.assertEqual(ann['start'], ent.base.start_char_index)
637+
self.assertEqual(ann['end'], ent.base.end_char_index)
638+
623639

624640
class CATWithDictNERSupTrainingTests(CATSupTrainingTests):
625641
from medcat.components.types import CoreComponentType

0 commit comments

Comments
 (0)