Skip to content

Commit 4a336d5

Browse files
committed
make segmentation optional when loading v6+ pose files
1 parent 869b61c commit 4a336d5

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/jabs/feature_extraction/features.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from pathlib import Path
2+
from typing import cast
23

34
import h5py
45
import numpy as np
56
import pandas as pd
67

78
import jabs.project.track_labels
89
from jabs.constants import COMPRESSION, COMPRESSION_OPTS_DEFAULT
9-
from jabs.pose_estimation import PoseEstimation, PoseHashException
10+
from jabs.pose_estimation import PoseEstimation, PoseEstimationV6, PoseHashException
1011

1112
from .base_features import BaseFeatureGroup
1213

@@ -113,7 +114,15 @@ def __init__(
113114
)
114115
self._cache_window = cache_window
115116
self._compute_social_features = pose_est.format_major_version >= 3
116-
self._compute_segmentation_features = pose_est.format_major_version >= 6
117+
118+
if (
119+
pose_est.format_major_version >= 6
120+
and cast(PoseEstimationV6, pose_est).has_segmentation
121+
):
122+
self._compute_segmentation_features = True
123+
else:
124+
self._compute_segmentation_features = False
125+
117126
distance_scale = (
118127
self._distance_scale_factor if self._distance_scale_factor is not None else 1.0
119128
)

src/jabs/pose_estimation/pose_est_v6.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,8 @@ def get_segmentation_data_per_frame(self, frame_index, identity: int) -> np.ndar
143143
def format_major_version(self) -> int:
144144
"""Returns the major version of the pose file format."""
145145
return 6
146+
147+
@property
148+
def has_segmentation(self) -> bool:
149+
"""Returns True if segmentation data is available."""
150+
return self._segmentation_dict["seg_data"] is not None

0 commit comments

Comments
 (0)