Skip to content

Commit 913a063

Browse files
authored
Merge pull request #36 from KumarLabJax/select-feature-subsets
Adds the ability to control feature subsets for classifiers. Changes ownership of project settings from UI components to project. Classifiers make a copy of project settings.
2 parents 46a89e9 + bc12346 commit 913a063

File tree

16 files changed

+524
-527
lines changed

16 files changed

+524
-527
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.idea
2+
.coverage
23
*.pyc
34
test-reports
45
.DS_Store

classify.py

Lines changed: 31 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,16 @@ def train_and_classify(
4141
override_classifier: typing.Optional[ClassifierType] = None,
4242
fps=DEFAULT_FPS,
4343
feature_dir: typing.Optional[str] = None):
44-
try:
45-
training_file, _ = load_training_data(training_file_path)
46-
except OSError as e:
47-
sys.exit(f"Unable to open training data\n{e}")
48-
49-
behavior = training_file['behavior']
50-
window_size = training_file['window_size']
51-
use_social = training_file['has_social_features']
44+
if not training_file_path.exists():
45+
sys.exit(f"Unable to open training data\n")
5246

5347
classifier = train(training_file_path, override_classifier)
54-
classify_pose(classifier, input_pose_file, out_dir, behavior, window_size,
55-
use_social, fps, feature_dir)
48+
classify_pose(classifier, input_pose_file, out_dir, behavior, fps, feature_dir)
5649

5750

5851
def classify_pose(classifier: Classifier, input_pose_file: Path, out_dir: Path,
59-
behavior: str, window_size: int, use_social: bool,
60-
fps=DEFAULT_FPS, feature_dir: typing.Optional[str] = None):
52+
behavior: str, fps=DEFAULT_FPS,
53+
feature_dir: typing.Optional[str] = None):
6154
pose_est = open_pose_file(input_pose_file)
6255
pose_stem = get_pose_stem(input_pose_file)
6356

@@ -67,36 +60,7 @@ def classify_pose(classifier: Classifier, input_pose_file: Path, out_dir: Path,
6760
dtype=np.int8)
6861
prediction_prob = np.zeros_like(prediction_labels, dtype=np.float32)
6962

70-
if use_social and pose_est.format_major_version < 3:
71-
print(f"Skipping {input_pose_file}")
72-
print(" classifier requires v3 or higher pose files")
73-
return
74-
75-
# make sure the pose file supports all required extended features
76-
supported_features = IdentityFeatures.get_available_extended_features(
77-
pose_est.format_major_version, pose_est.static_objects)
78-
required_features = classifier.extended_features
79-
extended_feature_check_ok = True
80-
for group, features in required_features.items():
81-
if group not in supported_features:
82-
extended_feature_check_ok = False
83-
else:
84-
for f in required_features[group]:
85-
if f not in supported_features[group]:
86-
extended_feature_check_ok = False
87-
if not extended_feature_check_ok:
88-
print(f"Skipping {input_pose_file}")
89-
print(" pose file does not support all required features")
90-
return
91-
92-
distance_scale_factor = 1.0
93-
if classifier.distance_unit == ProjectDistanceUnit.CM:
94-
if pose_est.cm_per_pixel is None:
95-
print(f"Skipping {input_pose_file}")
96-
print(" classifier uses cm distance units but pose file does not have cm_per_pixel attribute")
97-
return
98-
else:
99-
distance_scale_factor = pose_est.cm_per_pixel
63+
classifier_settings = classifier.project_settings
10064

10165
print(f"Classifying {input_pose_file}...")
10266

@@ -106,12 +70,10 @@ def classify_pose(classifier: Classifier, input_pose_file: Path, out_dir: Path,
10670
complete_as_percent=False, suffix='identities')
10771

10872
features = IdentityFeatures(
109-
input_pose_file, curr_id, feature_dir, pose_est, fps=fps,
110-
distance_scale_factor=distance_scale_factor,
111-
extended_features=classifier.extended_features
112-
).get_features(window_size, use_social)
113-
per_frame_features = pd.DataFrame(IdentityFeatures.merge_per_frame_features(features['per_frame'], use_social))
114-
window_features = pd.DataFrame(IdentityFeatures.merge_window_features(features['window'], use_social))
73+
input_pose_file, curr_id, feature_dir, pose_est, fps=fps, op_settings=classifier_settings
74+
).get_features(classifier_settings['window_size'])
75+
per_frame_features = pd.DataFrame(IdentityFeatures.merge_per_frame_features(features['per_frame']))
76+
window_features = pd.DataFrame(IdentityFeatures.merge_window_features(features['window']))
11577

11678
data = Classifier.combine_data(per_frame_features, window_features)
11779

@@ -154,18 +116,21 @@ def train(
154116
) -> Classifier:
155117

156118
try:
157-
training_file, _ = load_training_data(training_file)
119+
loaded_training_data, _ = load_training_data(training_file)
158120
except OSError as e:
159121
sys.exit(f"Unable to open training data\n{e}")
160122

161-
behavior = training_file['behavior']
123+
behavior = loaded_training_data['behavior']
162124

163125
classifier = Classifier()
126+
classifier.set_dict_settings(loaded_training_data['settings'])
127+
128+
# Override the classifier type
164129
if override_classifier is not None:
165130
classifier_type = override_classifier
166131
else:
167132
classifier_type = ClassifierType(
168-
training_file['classifier_type'])
133+
loaded_training_data['classifier_type'])
169134

170135
if classifier_type in classifier.classifier_choices():
171136
classifier.set_classifier(classifier_type)
@@ -177,27 +142,21 @@ def train(
177142
print("Training classifier for:", behavior)
178143
print(" Classifier Type: "
179144
f"{__CLASSIFIER_CHOICES[classifier.classifier_type]}")
180-
print(f" Window Size: {training_file['window_size']}")
181-
print(f" Social: {training_file['has_social_features']}")
182-
print(f" Balanced Labels: {training_file['balance_labels']}")
183-
print(f" Symmetric Behavior: {training_file['symmetric']}")
184-
print(f" Distance Unit: {training_file['distance_unit'].name}")
185-
186-
training_features = classifier.combine_data(training_file['per_frame'],
187-
training_file['window'])
145+
print(f" Window Size: {loaded_training_data['settings']['window_size']}")
146+
print(f" Social: {loaded_training_data['settings']['social']}")
147+
print(f" Balanced Labels: {loaded_training_data['settings']['balance_labels']}")
148+
print(f" Symmetric Behavior: {loaded_training_data['settings']['symmetric_behavior']}")
149+
print(f" CM Units: {loaded_training_data['settings']['cm_units']}")
150+
151+
training_features = classifier.combine_data(loaded_training_data['per_frame'],
152+
loaded_training_data['window'])
188153
classifier.train(
189154
{
190155
'training_data': training_features,
191-
'training_labels': training_file['labels']
156+
'training_labels': loaded_training_data['labels']
192157
},
193158
behavior,
194-
training_file['window_size'],
195-
training_file['has_social_features'],
196-
training_file['balance_labels'],
197-
training_file['symmetric'],
198-
training_file['extended_features'],
199-
training_file['distance_unit'],
200-
random_seed=training_file['training_seed']
159+
random_seed=loaded_training_data['training_seed']
201160
)
202161

203162
return classifier
@@ -296,9 +255,7 @@ def classify_main():
296255
sys.exit(e)
297256

298257
behavior = classifier.behavior_name
299-
window_size = classifier.window_size
300-
use_social = classifier.uses_social
301-
distance_unit = classifier.distance_unit
258+
classifier_settings = classifier.project_settings
302259

303260
print(f"Classifying using trained classifier: {args.classifier}")
304261
try:
@@ -307,12 +264,11 @@ def classify_main():
307264
except KeyError:
308265
sys.exit("Error: Classifier type not supported on this platform")
309266
print(f" Behavior: {behavior}")
310-
print(f" Window Size: {window_size}")
311-
print(f" Social: {use_social}")
312-
print(f" Distance Unit: {distance_unit.name}")
267+
print(f" Window Size: {classifier_settings['window_size']}")
268+
print(f" Social: {classifier_settings['social']}")
269+
print(f" CM Units: {classifier_settings['cm_units']}")
313270

314-
classify_pose(classifier, in_pose_path, out_dir, behavior, window_size,
315-
use_social, fps=args.fps, feature_dir=args.feature_dir)
271+
classify_pose(classifier, in_pose_path, out_dir, behavior, fps=args.fps, feature_dir=args.feature_dir)
316272

317273

318274
def train_main():

docs/user_guide/user_guide.md

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,32 @@
44

55
A JABS project is a directory of video files and their corresponding pose
66
estimation files. The first time a project directory is opened in JABS, it will
7-
create a subdirectory called "rotta", which contains various files created by
7+
create a subdirectory called "jabs", which contains various files created by
88
JABS to save project state, including labels and current predictions.
99

1010
### Example JABS project directory listing:
1111

1212
```text
13-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_08-00-00_500.avi
14-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_08-00-00_500_pose_est_v3.h5
15-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_09-00-00_500.avi
16-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_09-00-00_500_pose_est_v3.h5
17-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_10-00-00_500.avi
18-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_10-00-00_500_pose_est_v3.h5
19-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_11-00-00_500.avi
20-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_11-00-00_500_pose_est_v3.h5
21-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_12-00-00_500.avi
22-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_12-00-00_500_pose_est_v3.h5
23-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_13-00-00_500.avi
24-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_13-00-00_500_pose_est_v3.h5
25-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_14-00-00_500.avi
26-
NV11-CBAX2+2019-07-26+MDX0009_2019-07-26_14-00-00_500_pose_est_v3.h5 rotta
13+
VIDEO_1.avi
14+
VIDEO_1_pose_est_v3.h5
15+
VIDEO_2.avi
16+
VIDEO_2_pose_est_v3.h5
17+
VIDEO_3.avi
18+
VIDEO_3_pose_est_v3.h5
19+
VIDEO_4.avi
20+
VIDEO_4_pose_est_v3.h5
21+
VIDEO_5.avi
22+
VIDEO_5_pose_est_v3.h5
23+
VIDEO_6.avi
24+
VIDEO_6_pose_est_v3.h5
25+
VIDEO_7.avi
26+
VIDEO_7_pose_est_v3.h5 jabs
2727
```
2828

29-
## Initializing A JABS Project Directory
29+
## Initializing A JABS Project Directory
3030

3131
The first time you open a project directory with JABS it will create the "
32-
rotta" subdirectory. Features will be computed the first time the "Train" button
32+
jabs" subdirectory. Features will be computed the first time the "Train" button
3333
is clicked. This can be very time consuming depending on the number and length
3434
of videos in the project directory.
3535

@@ -76,48 +76,48 @@ will use up to 4 processes.
7676

7777
`./initialize_project.py -p8 -w2 -w5 -w10 <path/to/project/dir>`
7878

79-
## The Rotta Directory
79+
## The JABS Directory
8080

81-
JABS creates a subdirectory called "rotta" inside the project directory (this
82-
directory is called "rotta" for historical reasons and may change prior to the
83-
1.0.0 release of JABS). This directory contains app-specific data such as
84-
project settings, generated features, user labels, cache files, and the latest
85-
predictions.
81+
JABS creates a subdirectory called "jabs" inside the project directory. This
82+
directory contains app-specific data such as project settings, generated
83+
features, user labels, cache files, and the latest predictions.
8684

8785
project.json This file contains project settings and metadata.
8886

89-
### rotta/annotations
87+
### jabs/annotations
9088

9189
This directory stores the user's labels, stored in one JSON file per labeled
9290
video.
9391

94-
### rotta/archive
92+
### jabs/archive
9593

9694
This directory contains archived labels. These are compressed files (gzip)
9795
containing labels for behaviors that the user has removed from the project.
9896
JABS only archives labels. Trained classifiers and predictions are deleted if a
9997
user removes a behavior from a project.
10098

101-
### rotta/cache
99+
### jabs/cache
102100

103101
Files cached by JABS to improve performance. Some of these files may not be
104102
portable, so this directory should be deleted if a JABS project is copied to a
105103
different platform.
106104

107-
### rotta/classifiers
105+
### jabs/classifiers
108106

109107
This directory contains trained classifiers. Currently, these are stored in
110-
Python Pickle files and should be considered non-portable.
108+
Python Pickle files and should be considered non-portable. While non-portable,
109+
these files can be used alongside `classify.py classify --classifier` for
110+
predicting on the same machine where the GUI ran the training.
111111

112-
### rotta/features
112+
### jabs/features
113113

114114
This directory contains the computed features. There is one directory per
115115
project video, and within each video directory there will be one feature
116-
directory per identity. Feature files are usually portable, but JABS may need
116+
directory per identity. Feature files are portable, but JABS may need
117117
to recompute the features if they were created with a different version of
118118
JABS.
119119

120-
### rotta/predictions
120+
### jabs/predictions
121121

122122
This directory contains prediction files. There will be one subdirectory per
123123
behavior containing one prediction file per video. Prediction files are
@@ -173,9 +173,6 @@ tool (`classify.py`).
173173
means that 11 frames are included in the window feature calculations for
174174
each frame (5 previous frames, current frame, 5 following frames).
175175
- **New Window Size:** Add a new window size to the project.
176-
- **Social Feature Toggle:** Turn on/off social features (disabled if project
177-
includes pose file version 2). Allows training a classifier backwards
178-
compatible with V2 pose files using V3 or higher poses.
179176
- **Label Balancing Toggle:** Balances the training data by downsampling the class with more labels such that the distribution is equal.
180177
- **Symmetric Behavior Toggle:** Tells the classifier that the behavior is symmetric. A symmetric behavior is when left and right features are interchangeable.
181178
- **All k-fold Toggle:** Uses the maximum number of cross validation folds. Useful when you wish to compare classifier performance and may have an outlier that can be held-out.
@@ -222,6 +219,22 @@ tool (`classify.py`).
222219
mouse
223220
- **View→Overlay Landmarks:** toggle the overlay of arena landmarks over the
224221
video.
222+
- **Features:** Menu item for controlling per-behavior classifier settings.
223+
Menu items are disabled when at least 1 pose file in the project does not
224+
contain the data needed to calculate those features.
225+
- **Features→CM Units:** toggle using CM or pixel units
226+
(Warning! Changing this will require features to be re-calculated)
227+
- **Features→Enable Window Features:** toggle using statistical window features
228+
- **Features→Enable Signal Features:** toggle using fft-based window features
229+
- **Features→Enable Social Features:** toggle using social features (v3+ projects)
230+
- **Features→Enable Corners Features:** toggle using arena corner features
231+
(v5+ projects with arena corner static object)
232+
- **Features→Enable Lixit Features:** toggle using lixit features
233+
(v5+ projects with lixit static object)
234+
- **Features→Enable Food_hopper Features:** toggle using food hopper features
235+
(v5+ projects with food hopper static object)
236+
- **Features→Enable Segmentation Features:** toggle using segmentation features
237+
(v6+ projects)
225238

226239
**Track Overlay Example:**
227240
<img src="imgs/track_overlay.png" alt="Track Overlay" width=400 />
@@ -412,7 +425,7 @@ one video file.
412425
#### Location
413426

414427
The prediction files are saved
415-
in `<JABS project dir>/rotta/predictions/<behavior_name>/<video_name>.h5` if
428+
in `<JABS project dir>/jabs/predictions/<behavior_name>/<video_name>.h5` if
416429
they were generated by the JABS GUI. The `classify.py` script saves inference
417430
files in `<out-dir>/<behavior_name>/<video_name>.h5`
418431

0 commit comments

Comments
 (0)