@@ -231,7 +231,8 @@ def main(conf: Dict,
231231 as_half : bool = True ,
232232 image_list : Optional [Union [Path , List [str ]]] = None ,
233233 feature_path : Optional [Path ] = None ,
234- overwrite : bool = False ) -> Path :
234+ overwrite : bool = False ,
235+ mask_dir : Optional [Path ] = None ) -> Path :
235236 logger .info ('Extracting local features with configuration:'
236237 f'\n { pprint .pformat (conf )} ' )
237238
@@ -256,6 +257,14 @@ def main(conf: Dict,
256257 name = dataset .names [idx ]
257258 pred = model ({'image' : data ['image' ].to (device , non_blocking = True )})
258259 pred = {k : v [0 ].cpu ().numpy () for k , v in pred .items ()}
260+ if mask_dir is not None :
261+ mask_name = str (mask_dir / name ) + '.png'
262+ # print(mask_name)
263+ mask = cv2 .imread (mask_name )[:, :, 0 ]
264+ valid_keypoint = mask [pred ['keypoints' ][:, 1 ].astype ('int' ), pred ['keypoints' ][:, 0 ].astype ('int' )]
265+ pred ['keypoints' ] = pred ['keypoints' ][valid_keypoint > 0 ]
266+ pred ['descriptors' ] = pred ['descriptors' ][:, valid_keypoint > 0 ]
267+ pred ['scores' ] = pred ['scores' ][valid_keypoint > 0 ]
259268
260269 pred ['image_size' ] = original_size = data ['original_size' ][0 ].numpy ()
261270 if 'keypoints' in pred :
0 commit comments