11# Copyright (c) OpenMMLab. All rights reserved.
22from collections import namedtuple
3+ from copy import deepcopy
34from itertools import product
45from typing import Any , List , Optional , Tuple
56
67import numpy as np
78import torch
9+ from mmengine import dump
810from munkres import Munkres
911from torch import Tensor
1012
@@ -75,7 +77,9 @@ def _init_group():
7577 tag_list = [])
7678 return _group
7779
78- for i in keypoint_order :
80+ group_history = []
81+
82+ for idx , i in enumerate (keypoint_order ):
7983 # Get all valid candidate of the i-th keypoints
8084 valid = vals [i ] > val_thr
8185 if not valid .any ():
@@ -87,12 +91,22 @@ def _init_group():
8791
8892 if len (groups ) == 0 : # Initialize the group pool
8993 for tag , val , loc in zip (tags_i , vals_i , locs_i ):
94+
95+ # Check if the keypoint belongs to existing groups
96+ if len (groups ):
97+ prev_tags = np .stack ([g .tag_list [0 ] for g in groups ])
98+ dists = np .linalg .norm (prev_tags - tag , ord = 2 , axis = 1 )
99+ if dists .min () < 1 :
100+ continue
101+
90102 group = _init_group ()
91103 group .kpts [i ] = loc
92104 group .scores [i ] = val
93105 group .tag_list .append (tag )
94106
95107 groups .append (group )
108+ costs_copy = None
109+ matches = None
96110
97111 else : # Match keypoints to existing groups
98112 groups = groups [:max_groups ]
@@ -101,17 +115,18 @@ def _init_group():
101115 # Calculate distance matrix between group tags and tag candidates
102116 # of the i-th keypoint
103117 # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
104- diff = tags_i [:, None ] - np .array (group_tags )[None ]
118+ diff = (tags_i [:, None ] -
119+ np .array (group_tags )[None ]).astype (np .float64 )
105120 dists = np .linalg .norm (diff , ord = 2 , axis = 2 )
106121 num_kpts , num_groups = dists .shape [:2 ]
107122
108- # Experimental cost function for keypoint-group matching
123+ # Experimental cost function for keypoint-group matching2
109124 costs = np .round (dists ) * 100 - vals_i [..., None ]
125+
110126 if num_kpts > num_groups :
111- padding = np .full ((num_kpts , num_kpts - num_groups ),
112- 1e10 ,
113- dtype = np .float32 )
127+ padding = np .full ((num_kpts , num_kpts - num_groups ), 1e10 )
114128 costs = np .concatenate ((costs , padding ), axis = 1 )
129+ costs_copy = costs .copy ()
115130
116131 # Match keypoints and groups by Munkres algorithm
117132 matches = munkres .compute (costs )
@@ -121,13 +136,30 @@ def _init_group():
121136 # Add the keypoint to the matched group
122137 group = groups [group_idx ]
123138 else :
124- # Initialize a new group with unmatched keypoint
125- group = _init_group ()
126- groups .append (group )
127-
128- group .kpts [i ] = locs_i [kpt_idx ]
129- group .scores [i ] = vals_i [kpt_idx ]
130- group .tag_list .append (tags_i [kpt_idx ])
139+ # if dists[kpt_idx].min() < 0.2:
140+ if False :
141+ group = None
142+ else :
143+ # Initialize a new group with unmatched keypoint
144+ group = _init_group ()
145+ groups .append (group )
146+ if group is not None :
147+ group .kpts [i ] = locs_i [kpt_idx ]
148+ group .scores [i ] = vals_i [kpt_idx ]
149+ group .tag_list .append (tags_i [kpt_idx ])
150+
151+ out = {
152+ 'idx' : idx ,
153+ 'i' : i ,
154+ 'costs' : costs_copy ,
155+ 'matches' : matches ,
156+ 'kpts' : np .array ([g .kpts for g in groups ]),
157+ 'scores' : np .array ([g .scores for g in groups ]),
158+ 'tag_list' : [np .array (g .tag_list ) for g in groups ],
159+ }
160+ group_history .append (deepcopy (out ))
161+
162+ dump (group_history , 'group_history.pkl' )
131163
132164 groups = groups [:max_groups ]
133165 if groups :
@@ -210,7 +242,7 @@ def __init__(
210242 decode_gaussian_kernel : int = 3 ,
211243 decode_keypoint_thr : float = 0.1 ,
212244 decode_tag_thr : float = 1.0 ,
213- decode_topk : int = 20 ,
245+ decode_topk : int = 30 ,
214246 decode_max_instances : Optional [int ] = None ,
215247 ) -> None :
216248 super ().__init__ ()
@@ -336,6 +368,12 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
336368 B , K , H , W = batch_heatmaps .shape
337369 L = batch_tags .shape [1 ] // K
338370
371+ # Heatmap NMS
372+ dump (batch_heatmaps .cpu ().numpy (), 'heatmaps.pkl' )
373+ batch_heatmaps = batch_heatmap_nms (batch_heatmaps ,
374+ self .decode_nms_kernel )
375+ dump (batch_heatmaps .cpu ().numpy (), 'heatmaps_nms.pkl' )
376+
339377 # shape of topk_val, top_indices: (B, K, TopK)
340378 topk_vals , topk_indices = batch_heatmaps .flatten (- 2 , - 1 ).topk (
341379 k , dim = - 1 )
@@ -433,9 +471,8 @@ def _fill_missing_keypoints(self, keypoints: np.ndarray,
433471 cost_map = np .round (dist_map ) * 100 - heatmaps [k ] # H, W
434472 y , x = np .unravel_index (np .argmin (cost_map ), shape = (H , W ))
435473 keypoints [n , k ] = [x , y ]
436- keypoint_scores [n , k ] = heatmaps [k , y , x ]
437474
438- return keypoints , keypoint_scores
475+ return keypoints
439476
440477 def batch_decode (self , batch_heatmaps : Tensor , batch_tags : Tensor
441478 ) -> Tuple [List [np .ndarray ], List [np .ndarray ]]:
@@ -457,15 +494,12 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
457494 batch, each is in shape (N, K). It usually represents the
458495 confidience of the keypoint prediction
459496 """
497+
460498 B , _ , H , W = batch_heatmaps .shape
461499 assert batch_tags .shape [0 ] == B and batch_tags .shape [2 :4 ] == (H , W ), (
462500 f'Mismatched shapes of heatmap ({ batch_heatmaps .shape } ) and '
463501 f'tagging map ({ batch_tags .shape } )' )
464502
465- # Heatmap NMS
466- batch_heatmaps = batch_heatmap_nms (batch_heatmaps ,
467- self .decode_nms_kernel )
468-
469503 # Get top-k in each heatmap and and convert to numpy
470504 batch_topk_vals , batch_topk_tags , batch_topk_locs = to_numpy (
471505 self ._get_batch_topk (
@@ -489,7 +523,7 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
489523
490524 if keypoints .size > 0 :
491525 # identify missing keypoints
492- keypoints , scores = self ._fill_missing_keypoints (
526+ keypoints = self ._fill_missing_keypoints (
493527 keypoints , scores , heatmaps , tags )
494528
495529 # refine keypoint coordinates according to heatmap distribution
@@ -500,6 +534,8 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
500534 blur_kernel_size = self .decode_gaussian_kernel )
501535 else :
502536 keypoints = refine_keypoints (keypoints , heatmaps )
537+ # keypoints += 0.75
538+ keypoints += 0.5
503539
504540 batch_keypoints .append (keypoints )
505541 batch_keypoint_scores .append (scores )
0 commit comments