2020from .prediction_manager import PredictionManager
2121from .project_paths import ProjectPaths
2222from .project_utils import to_safe_name
23+ from .session_tracker import SessionTracker
2324from .settings_manager import SettingsManager
2425from .track_labels import TrackLabels
2526from .video_labels import VideoLabels
@@ -52,7 +53,9 @@ class Project:
5253 project_paths: ProjectPaths instance for this project.
5354 """
5455
55- def __init__ (self , project_path , use_cache = True , enable_video_check = True ):
56+ def __init__ (
57+ self , project_path , use_cache = True , enable_video_check = True , enable_session_tracker = True
58+ ):
5659 self ._paths = ProjectPaths (Path (project_path ), use_cache = use_cache )
5760 self ._paths .create_directories ()
5861 self ._total_project_identities = 0
@@ -62,10 +65,16 @@ def __init__(self, project_path, use_cache=True, enable_video_check=True):
6265 self ._video_manager = VideoManager (self ._paths , self ._settings_manager , enable_video_check )
6366 self ._feature_manager = FeatureManager (self ._paths , self ._video_manager .videos )
6467 self ._prediction_manager = PredictionManager (self )
68+ self ._session_tracker = SessionTracker (self , tracking_enabled = enable_session_tracker )
6569
6670 # write out the defaults to the project file
6771 self ._settings_manager .save_project_file ({"defaults" : self .get_project_defaults ()})
6872
73+ # Start a session tracker for this project.
74+ # Since the session has a reference to the Project, the Project should be fully initialized before starting
75+ # the session tracker.
76+ self ._session_tracker .start_session ()
77+
6978 def _validate_pose_files (self ):
7079 """Ensure all videos have corresponding pose files."""
7180 err = False
@@ -143,6 +152,11 @@ def labeler(self) -> str:
143152 """
144153 return getpass .getuser ()
145154
155+ @property
156+ def session_tracker (self ) -> SessionTracker | None :
157+ """get the session tracker for this project"""
158+ return self ._session_tracker
159+
146160 def load_pose_est (self , video_path : Path ) -> PoseEstimation :
147161 """return a PoseEstimation object for a given video path
148162
@@ -391,7 +405,7 @@ def counts(self, behavior):
391405 """
392406 counts = {}
393407 for video in self ._video_manager .videos :
394- counts [video ] = self .read_counts (video , behavior )
408+ counts [video ] = self .load_counts (video , behavior )
395409 return counts
396410
397411 def get_labeled_features (
@@ -544,19 +558,32 @@ def __has_pose(self, vid: str):
544558 return False
545559 return True
546560
547- def read_counts (self , video , behavior ) -> list [ tuple ]:
548- """read labeled frame and bout counts from json file
561+ def load_counts (self , video , behavior ) -> dict [ str , tuple [ int , int ] ]:
562+ """load labeled frame and bout counts from json file
549563
550564 Returns:
551- list of labeled frame and bout counts for each identity for
552- the specified behavior. Each element in the list is a tuple of the form
553- (
554- identity,
555- (fragmented behavior frame count, fragmented not behavior frame count),
556- (fragmented behavior bout count, fragmented not behavior bout count),
557- (unfragmented behavior frame count, unfragmented not behavior frame count)
558- (unfragmented behavior bout count, unfragmented not behavior bout count)
559- )
565+ dict of labeled frame and bout counts for each identity for
566+ the specified behavior.
567+ {
568+ identity: {
569+ "fragmented_frame_counts": (
570+ fragmented behavior frame count,
571+ fragmented not behavior frame count
572+ ),
573+ "fragmented_bout_counts": (
574+ fragmented behavior bout count,
575+ fragmented not behavior bout count
576+ ),
577+ "unfragmented_frame_counts": (
578+ unfragmented behavior frame count,
579+ unfragmented not behavior frame count
580+ )
581+ "unfragmented_bout_counts": (
582+ unfragmented behavior bout count,
583+ unfragmented not behavior bout count
584+ )
585+ }
586+ }
560587
561588 Note: "unfragmented" counts labels where identity drops out. "fragmented" does not,
562589 so if an identity drops out during a bout, the bout will be split in the fragmented counts but will
@@ -585,7 +612,7 @@ def count_labels(
585612
586613 video_filename = Path (video ).name
587614 path = self ._paths .annotations_dir / Path (video_filename ).with_suffix (".json" )
588- counts = []
615+ counts = {}
589616
590617 if path .exists ():
591618 with path .open () as f :
@@ -605,13 +632,13 @@ def count_labels(
605632 # unless the user creates some new labels over frames without identity
606633 unfragmented_counts = fragmented_counts
607634
608- counts . append (
609- (
610- identity ,
611- fragmented_counts [0 ],
612- fragmented_counts [1 ],
613- unfragmented_counts [0 ],
614- unfragmented_counts [1 ],
615- )
616- )
635+ # identity is stored as a string in the JSON file because it's used as a key. Turn it back
636+ # into an int as used internally by JABS
637+ counts [ int ( identity )] = {
638+ "fragmented_frame_counts" : fragmented_counts [0 ],
639+ "fragmented_bout_counts" : fragmented_counts [1 ],
640+ "unfragmented_frame_counts" : unfragmented_counts [0 ],
641+ "unfragmented_bout_counts" : unfragmented_counts [1 ],
642+ }
643+
617644 return counts
0 commit comments