Skip to content

Commit 61c23ef

Browse files
authored
Merge pull request #233 from KumarLabJax/pose-v6-segmentation-optional
make segmentation optional when loading v6+ pose files
2 parents 869b61c + 720de8c commit 61c23ef

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/jabs/feature_extraction/features.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from pathlib import Path
2+
from typing import cast
23

34
import h5py
45
import numpy as np
56
import pandas as pd
67

78
import jabs.project.track_labels
89
from jabs.constants import COMPRESSION, COMPRESSION_OPTS_DEFAULT
9-
from jabs.pose_estimation import PoseEstimation, PoseHashException
10+
from jabs.pose_estimation import PoseEstimation, PoseEstimationV6, PoseHashException
1011

1112
from .base_features import BaseFeatureGroup
1213

@@ -113,7 +114,12 @@ def __init__(
113114
)
114115
self._cache_window = cache_window
115116
self._compute_social_features = pose_est.format_major_version >= 3
116-
self._compute_segmentation_features = pose_est.format_major_version >= 6
117+
118+
self._compute_segmentation_features = (
119+
pose_est.format_major_version >= 6
120+
and cast(PoseEstimationV6, pose_est).has_segmentation
121+
)
122+
117123
distance_scale = (
118124
self._distance_scale_factor if self._distance_scale_factor is not None else 1.0
119125
)

src/jabs/pose_estimation/pose_est_v6.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,8 @@ def get_segmentation_data_per_frame(self, frame_index, identity: int) -> np.ndar
143143
def format_major_version(self) -> int:
144144
"""Returns the major version of the pose file format."""
145145
return 6
146+
147+
@property
148+
def has_segmentation(self) -> bool:
149+
"""Returns True if segmentation data is available."""
150+
return self._segmentation_dict["seg_data"] is not None

0 commit comments

Comments
 (0)