Skip to content

Commit f99ca5b

Browse files
authored
Merge pull request #179 from KumarLabJax/session-tracker
initial implementation of jabs labeling session tracking
2 parents 1c1046a + e3e507e commit f99ca5b

File tree

14 files changed

+566
-73
lines changed

14 files changed

+566
-73
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "jabs-behavior-classifier"
3-
version = "0.32.2"
3+
version = "0.33.0"
44
license = "Proprietary"
55
repository = "https://github.com/KumarLabJax/JABS-behavior-classifier"
66
description = ""

src/jabs/classifier/classifier.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -622,32 +622,38 @@ def count_label_threshold(all_counts: dict):
622622
623623
all_counts is a dict with the following form
624624
{
625-
'<video name>': [
626-
(
627-
<identity>,
628-
(behavior frame count - fragmented, not behavior frame count - fragmented),
629-
(behavior bout count - fragmented, not behavior bout count - fragmented),
630-
(behavior frame count - unfragmented, not behavior frame count - unfragmented),
631-
(behavior bout count - unfragmented, not behavior bout count - unfragmented),
632-
),
633-
]
625+
'<video name>': {
626+
<identity>: {
627+
"fragmented_frame_counts": (
628+
behavior frame count: fragmented,
629+
not behavior frame count: fragmented),
630+
"fragmented_bout_counts": (
631+
behavior bout count: fragmented,
632+
not behavior bout count: fragmented
633+
),
634+
"unfragmented_frame_counts": (
635+
behavior frame count: unfragmented,
636+
not behavior frame count: unfragmented
637+
),
638+
"unfragmented_bout_counts": (
639+
behavior bout count: unfragmented,
640+
not behavior bout count: unfragmented
641+
),
642+
},
643+
}
634644
}
635645
636646
Returns:
637647
number of groups that meet label criteria
638648
639649
Note: uses "fragmented" label counts, since these reflect the counts of labels that are usable for training
640-
fragmented counts are:
641-
642-
count[1][0] - behavior frame count
643-
count[1][1] - not behavior frame count
644650
"""
645651
group_count = 0
646-
for _, counts in all_counts.items():
647-
for count in counts:
652+
for video in all_counts:
653+
for identity_count in all_counts[video].values():
648654
if (
649-
count[1][0] >= Classifier.LABEL_THRESHOLD
650-
and count[1][1] >= Classifier.LABEL_THRESHOLD
655+
identity_count["fragmented_frame_counts"][0] >= Classifier.LABEL_THRESHOLD
656+
and identity_count["fragmented_frame_counts"][1] >= Classifier.LABEL_THRESHOLD
651657
):
652658
group_count += 1
653659
return group_count

src/jabs/project/project.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .prediction_manager import PredictionManager
2121
from .project_paths import ProjectPaths
2222
from .project_utils import to_safe_name
23+
from .session_tracker import SessionTracker
2324
from .settings_manager import SettingsManager
2425
from .track_labels import TrackLabels
2526
from .video_labels import VideoLabels
@@ -52,7 +53,9 @@ class Project:
5253
project_paths: ProjectPaths instance for this project.
5354
"""
5455

55-
def __init__(self, project_path, use_cache=True, enable_video_check=True):
56+
def __init__(
57+
self, project_path, use_cache=True, enable_video_check=True, enable_session_tracker=True
58+
):
5659
self._paths = ProjectPaths(Path(project_path), use_cache=use_cache)
5760
self._paths.create_directories()
5861
self._total_project_identities = 0
@@ -62,10 +65,16 @@ def __init__(self, project_path, use_cache=True, enable_video_check=True):
6265
self._video_manager = VideoManager(self._paths, self._settings_manager, enable_video_check)
6366
self._feature_manager = FeatureManager(self._paths, self._video_manager.videos)
6467
self._prediction_manager = PredictionManager(self)
68+
self._session_tracker = SessionTracker(self, tracking_enabled=enable_session_tracker)
6569

6670
# write out the defaults to the project file
6771
self._settings_manager.save_project_file({"defaults": self.get_project_defaults()})
6872

73+
# Start a session tracker for this project.
74+
# Since the session has a reference to the Project, the Project should be fully initialized before starting
75+
# the session tracker.
76+
self._session_tracker.start_session()
77+
6978
def _validate_pose_files(self):
7079
"""Ensure all videos have corresponding pose files."""
7180
err = False
@@ -143,6 +152,11 @@ def labeler(self) -> str:
143152
"""
144153
return getpass.getuser()
145154

155+
@property
156+
def session_tracker(self) -> SessionTracker | None:
157+
"""get the session tracker for this project"""
158+
return self._session_tracker
159+
146160
def load_pose_est(self, video_path: Path) -> PoseEstimation:
147161
"""return a PoseEstimation object for a given video path
148162
@@ -391,7 +405,7 @@ def counts(self, behavior):
391405
"""
392406
counts = {}
393407
for video in self._video_manager.videos:
394-
counts[video] = self.read_counts(video, behavior)
408+
counts[video] = self.load_counts(video, behavior)
395409
return counts
396410

397411
def get_labeled_features(
@@ -544,19 +558,32 @@ def __has_pose(self, vid: str):
544558
return False
545559
return True
546560

547-
def read_counts(self, video, behavior) -> list[tuple]:
548-
"""read labeled frame and bout counts from json file
561+
def load_counts(self, video, behavior) -> dict[str, tuple[int, int]]:
562+
"""load labeled frame and bout counts from json file
549563
550564
Returns:
551-
list of labeled frame and bout counts for each identity for
552-
the specified behavior. Each element in the list is a tuple of the form
553-
(
554-
identity,
555-
(fragmented behavior frame count, fragmented not behavior frame count),
556-
(fragmented behavior bout count, fragmented not behavior bout count),
557-
(unfragmented behavior frame count, unfragmented not behavior frame count)
558-
(unfragmented behavior bout count, unfragmented not behavior bout count)
559-
)
565+
dict of labeled frame and bout counts for each identity for
566+
the specified behavior.
567+
{
568+
identity: {
569+
"fragmented_frame_counts": (
570+
fragmented behavior frame count,
571+
fragmented not behavior frame count
572+
),
573+
"fragmented_bout_counts": (
574+
fragmented behavior bout count,
575+
fragmented not behavior bout count
576+
),
577+
"unfragmented_frame_counts": (
578+
unfragmented behavior frame count,
579+
unfragmented not behavior frame count
580+
)
581+
"unfragmented_bout_counts": (
582+
unfragmented behavior bout count,
583+
unfragmented not behavior bout count
584+
)
585+
}
586+
}
560587
561588
Note: "unfragmented" counts labels where identity drops out. "fragmented" does not,
562589
so if an identity drops out during a bout, the bout will be split in the fragmented counts but will
@@ -585,7 +612,7 @@ def count_labels(
585612

586613
video_filename = Path(video).name
587614
path = self._paths.annotations_dir / Path(video_filename).with_suffix(".json")
588-
counts = []
615+
counts = {}
589616

590617
if path.exists():
591618
with path.open() as f:
@@ -605,13 +632,13 @@ def count_labels(
605632
# unless the user creates some new labels over frames without identity
606633
unfragmented_counts = fragmented_counts
607634

608-
counts.append(
609-
(
610-
identity,
611-
fragmented_counts[0],
612-
fragmented_counts[1],
613-
unfragmented_counts[0],
614-
unfragmented_counts[1],
615-
)
616-
)
635+
# identity is stored as a string in the JSON file because it's used as a key. Turn it back
636+
# into an int as used internally by JABS
637+
counts[int(identity)] = {
638+
"fragmented_frame_counts": fragmented_counts[0],
639+
"fragmented_bout_counts": fragmented_counts[1],
640+
"unfragmented_frame_counts": unfragmented_counts[0],
641+
"unfragmented_bout_counts": unfragmented_counts[1],
642+
}
643+
617644
return counts

src/jabs/project/project_paths.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, base_path: Path, use_cache: bool = True):
1616
self._prediction_dir = self._jabs_dir / "predictions"
1717
self._classifier_dir = self._jabs_dir / "classifiers"
1818
self._archive_dir = self._jabs_dir / "archive"
19+
self._session_dir = self._jabs_dir / "session"
1920
self._cache_dir = self._jabs_dir / "cache" if use_cache else None
2021

2122
self._project_file = self._jabs_dir / self.__PROJECT_FILE
@@ -65,13 +66,19 @@ def cache_dir(self) -> Path | None:
6566
"""Get the path to the cache directory."""
6667
return self._cache_dir
6768

69+
@property
70+
def session_dir(self) -> Path:
71+
"""Get the path to the session directory."""
72+
return self._session_dir
73+
6874
def create_directories(self):
6975
"""Create all necessary directories for the project."""
7076
self._annotations_dir.mkdir(parents=True, exist_ok=True)
7177
self._feature_dir.mkdir(parents=True, exist_ok=True)
7278
self._prediction_dir.mkdir(parents=True, exist_ok=True)
7379
self._classifier_dir.mkdir(parents=True, exist_ok=True)
7480
self._archive_dir.mkdir(parents=True, exist_ok=True)
81+
self._session_dir.mkdir(parents=True, exist_ok=True)
7582

7683
if self._cache_dir:
7784
self._cache_dir.mkdir(parents=True, exist_ok=True)

src/jabs/project/project_pruning.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@ def get_videos_to_prune(project: Project, behavior: str | None = None) -> list[V
2222
behavior (str | None): The behavior to check for labels. If None, checks all behaviors.
2323
"""
2424

25-
def check_label_counts(label_counts: list[tuple[str, tuple[int, int]]]) -> bool:
26-
"""Check if there are any labels for the given counts."""
27-
return any(count[1][0] > 0 or count[1][1] > 0 for count in label_counts)
25+
def check_label_counts(label_counts: dict[str, dict[str, tuple[int, int]]]) -> bool:
26+
"""Return True if any count in label_counts is greater than zero."""
27+
for identity_counts in label_counts.values():
28+
for counts in identity_counts.values():
29+
if any(count > 0 for count in counts):
30+
return True
31+
return False
2832

2933
videos_to_remove = []
3034
for video in project.video_manager.videos:
@@ -34,11 +38,11 @@ def check_label_counts(label_counts: list[tuple[str, tuple[int, int]]]) -> bool:
3438

3539
has_labels = False
3640
if behavior:
37-
counts = project.read_counts(video, behavior)
41+
counts = project.load_counts(video, behavior)
3842
has_labels = check_label_counts(counts)
3943
else:
4044
for b in project.settings_manager.behavior_names:
41-
counts = project.read_counts(video, b)
45+
counts = project.load_counts(video, b)
4246
has_labels = check_label_counts(counts)
4347

4448
# found labels for at least one behavior, so we can stop checking

0 commit comments

Comments
 (0)