diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index e6a4dde..0000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "yolov5"]
-	path = yolov5
-	url = https://github.com/ultralytics/yolov5
diff --git a/infer.py b/infer.py
deleted file mode 100644
index adf15a4..0000000
--- a/infer.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import torch
-import cv2
-import argparse
-import numpy as np
-from tqdm import tqdm
-from pathlib import Path
-from torchvision import transforms as T
-
-from pose.models import get_pose_model
-from pose.utils.boxes import letterbox, scale_boxes, non_max_suppression, xyxy2xywh
-from pose.utils.decode import get_final_preds, get_simdr_final_preds
-from pose.utils.utils import setup_cudnn, get_affine_transform, draw_keypoints
-from pose.utils.utils import VideoReader, VideoWriter, WebcamStream, FPS
-
-import sys
-sys.path.insert(0, 'yolov5')
-from yolov5.models.experimental import attempt_load
-
-
-class Pose:
-    def __init__(self,
-        det_model,
-        pose_model,
-        img_size=640,
-        conf_thres=0.25,
-        iou_thres=0.45,
-    ) -> None:
-        self.img_size = img_size
-        self.conf_thres = conf_thres
-        self.iou_thres = iou_thres
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        self.det_model = attempt_load(det_model, map_location=self.device)
-        self.det_model = self.det_model.to(self.device)
-
-        self.model_name = pose_model
-        self.pose_model = get_pose_model(pose_model)
-        self.pose_model.load_state_dict(torch.load(pose_model, map_location='cpu'))
-        self.pose_model = self.pose_model.to(self.device)
-        self.pose_model.eval()
-
-        self.patch_size = (192, 256)
-
-        self.pose_transform = T.Compose([
-            T.ToTensor(),
-            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
-        ])
-
-        self.coco_skeletons = [
-            [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
-            [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]
-        ]
-
-    def preprocess(self, image):
-        img = letterbox(image, new_shape=self.img_size)
-        img = np.ascontiguousarray(img.transpose((2, 0, 1)))
-        img = torch.from_numpy(img).to(self.device)
-        img = img.float() / 255.0
-        img = img[None]
-        return img
-
-    def box_to_center_scale(self, boxes, pixel_std=200):
-        boxes = xyxy2xywh(boxes)
-        r = self.patch_size[0] / self.patch_size[1]
-        mask = boxes[:, 2] > boxes[:, 3] * r
-        boxes[mask, 3] = boxes[mask, 2] / r
-        boxes[~mask, 2] = boxes[~mask, 3] * r
-        boxes[:, 2:] /= pixel_std
-        boxes[:, 2:] *= 1.25
-        return boxes
-
-    def predict_poses(self, boxes, img):
-        image_patches = []
-        for cx, cy, w, h in boxes:
-            trans = get_affine_transform(np.array([cx, cy]), np.array([w, h]), self.patch_size)
-            img_patch = cv2.warpAffine(img, trans, self.patch_size, flags=cv2.INTER_LINEAR)
-            img_patch = self.pose_transform(img_patch)
-            image_patches.append(img_patch)
-
-        image_patches = torch.stack(image_patches).to(self.device)
-        return self.pose_model(image_patches)
-
-    def postprocess(self, pred, img1, img0):
-        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=0)
-
-        for det in pred:
-            if len(det):
-                boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu()
-                boxes = self.box_to_center_scale(boxes)
-                outputs = self.predict_poses(boxes, img0)
-
-                if 'simdr' in self.model_name:
-                    coords = get_simdr_final_preds(*outputs, boxes, self.patch_size)
-                else:
-                    coords = get_final_preds(outputs, boxes)
-
-                draw_keypoints(img0, coords, self.coco_skeletons)
-
-    @torch.no_grad()
-    def predict(self, image):
-        img = self.preprocess(image)
-        pred = self.det_model(img)[0]
-        self.postprocess(pred, img, image)
-        return image
-
-
-def argument_parser():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--source', type=str, default='assests/test.jpg')
-    parser.add_argument('--det-model', type=str, default='checkpoints/crowdhuman_yolov5m.pt')
-    parser.add_argument('--pose-model', type=str, default='checkpoints/pretrained/simdr_hrnet_w32_256x192.pth')
-    parser.add_argument('--img-size', type=int, default=640)
-    parser.add_argument('--conf-thres', type=float, default=0.4)
-    parser.add_argument('--iou-thres', type=float, default=0.5)
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    setup_cudnn()
-    args = argument_parser()
-    pose = Pose(
-        args.det_model,
-        args.pose_model,
-        args.img_size,
-        args.conf_thres,
-        args.iou_thres
-    )
-
-    source = Path(args.source)
-
-    if source.is_file() and source.suffix in ['.jpg', '.png']:
-        image = cv2.imread(str(source))
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        output = pose.predict(image)
-        cv2.imwrite(f"{str(source).rsplit('.', maxsplit=1)[0]}_out.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
-
-    elif source.is_dir():
-        files = source.glob("*.jpg")
-        for file in files:
-            image = cv2.imread(str(file))
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            output = pose.predict(image)
-            cv2.imwrite(f"{str(file).rsplit('.', maxsplit=1)[0]}_out.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
-
-    elif source.is_file() and source.suffix in ['.mp4', '.avi']:
-        reader = VideoReader(args.source)
-        writer = VideoWriter(f"{args.source.rsplit('.', maxsplit=1)[0]}_out.mp4", reader.fps)
-        fps = FPS(len(reader.frames))
-
-        for frame in tqdm(reader):
-            fps.start()
-            output = pose.predict(frame.numpy())
-            fps.stop(False)
-            writer.update(output)
-
-        print(f"FPS: {fps.fps}")
-        writer.write()
-
-    else:
-        webcam = WebcamStream()
-        fps = FPS()
-
-        for frame in webcam:
-            fps.start()
-            output = pose.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-            fps.stop()
-            cv2.imshow('frame', cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
\ No newline at end of file
diff --git a/pose/infer.py b/pose/infer.py
new file mode 100644
index 0000000..91a843a
--- /dev/null
+++ b/pose/infer.py
@@ -0,0 +1,236 @@
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from ultralytics import YOLO
+from torchvision import transforms as T
+from tqdm import tqdm
+
+from pose.models import get_pose_model
+from pose.utils.boxes import letterbox, non_max_suppression, scale_boxes, xyxy2xywh
+from pose.utils.decode import get_final_preds, get_simdr_final_preds
+from pose.utils.utils import (
+    FPS,
+    VideoReader,
+    VideoWriter,
+    WebcamStream,
+    draw_bbox,
+    draw_keypoints,
+    get_affine_transform,
+    setup_cudnn,
+)
+
+
+class Pose:
+    def __init__(
+        self,
+        det_model: str,
+        pose_model: str,
+        img_size: int = 640,
+        conf_thres: float = 0.25,
+        iou_thres: float = 0.45,
+    ) -> None:
+        self.img_size = img_size
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        if "yolov5" in det_model:
+            self.det_model_type = "yolov5"
+            self.det_model = torch.hub.load(
+                "ultralytics/yolov5", "custom", path=det_model, force_reload=True
+            )
+            self.det_model = self.det_model.to(self.device)
+        else:
+            self.det_model_type = "yolo"
+            self.det_model = YOLO(det_model)
+            self.det_model = self.det_model.to(self.device)
+
+        self.model_name = pose_model
+        self.pose_model = get_pose_model(pose_model)
+        self.pose_model.load_state_dict(torch.load(pose_model, map_location="cpu"))
+        self.pose_model = self.pose_model.to(self.device)
+        self.pose_model.eval()
+
+        self.patch_size = (192, 256)
+
+        self.pose_transform = T.Compose(
+            [T.ToTensor(), T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
+        )
+
+        self.coco_skeletons = [
+            [16, 14],
+            [14, 12],
+            [17, 15],
+            [15, 13],
+            [12, 13],
+            [6, 12],
+            [7, 13],
+            [6, 7],
+            [6, 8],
+            [7, 9],
+            [8, 10],
+            [9, 11],
+            [2, 3],
+            [1, 2],
+            [1, 3],
+            [2, 4],
+            [3, 5],
+            [4, 6],
+            [5, 7],
+        ]
+
+    def preprocess(self, image):
+        img = letterbox(image, new_shape=self.img_size)
+        img = np.ascontiguousarray(img.transpose((2, 0, 1)))
+        img = torch.from_numpy(img).to(self.device)
+        img = img.float() / 255.0
+        img = img[None]
+        return img
+
+    def box_to_center_scale(self, boxes, pixel_std=200):
+        boxes = xyxy2xywh(boxes)
+        r = self.patch_size[0] / self.patch_size[1]
+        mask = boxes[:, 2] > boxes[:, 3] * r
+        boxes[mask, 3] = boxes[mask, 2] / r
+        boxes[~mask, 2] = boxes[~mask, 3] * r
+        boxes[:, 2:] /= pixel_std
+        boxes[:, 2:] *= 1.25
+        return boxes
+
+    def predict_poses(self, boxes, img):
+        image_patches = []
+        for cx, cy, w, h in boxes:
+            trans = get_affine_transform(
+                np.array([cx, cy]), np.array([w, h]), self.patch_size
+            )
+            img_patch = cv2.warpAffine(
+                img, trans, self.patch_size, flags=cv2.INTER_LINEAR
+            )
+            img_patch = self.pose_transform(img_patch)
+            image_patches.append(img_patch)
+
+        image_patches = torch.stack(image_patches).to(self.device)
+        return self.pose_model(image_patches)
+
+    def postprocess(self, pred, img1, img0):
+        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, classes=0)
+
+        for det in pred:
+            if len(det):
+                boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu()
+                boxes = self.box_to_center_scale(boxes)
+                outputs = self.predict_poses(boxes, img0)
+
+                if "simdr" in self.model_name.lower():
+                    coords = get_simdr_final_preds(*outputs, boxes, self.patch_size)
+                else:
+                    coords = get_final_preds(outputs, boxes)
+
+                img0 = draw_keypoints(img0, coords, self.coco_skeletons)
+                img0 = draw_bbox(img0, det.cpu().numpy())
+
+    @torch.inference_mode()
+    def predict(self, image):
+        img = self.preprocess(image)
+        pred = self.det_model(img)[0]
+        self.postprocess(pred, img, image)
+        return image
+
+
+def argument_parser():
+    parser = ArgumentParser(
+        description="Pose Estimation",
+        formatter_class=ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--source",
+        type=str,
+        default="assests/test.jpg",
+        help="Path to image, video or webcam",
+    )
+    parser.add_argument(
+        "--det-model",
+        type=str,
+        default="checkpoints/crowdhuman_yolov5m.pt",
+        help="Human detection model",
+    )
+    parser.add_argument(
+        "--pose-model",
+        type=str,
+        default="checkpoints/pretrained/simdr_hrnet_w32_256x192.pth",
+        help="Pose estimation model",
+    )
+    parser.add_argument("--img-size", type=int, default=640, help="Image size")
+    parser.add_argument(
+        "--conf-thres", type=float, default=0.5, help="Confidence threshold"
+    )
+    parser.add_argument("--iou-thres", type=float, default=0.5, help="IOU threshold")
+    return parser.parse_args()
+
+
+def main():
+    setup_cudnn()
+    args = argument_parser()
+    pose = Pose(
+        det_model=args.det_model,
+        pose_model=args.pose_model,
+        img_size=args.img_size,
+        conf_thres=args.conf_thres,
+        iou_thres=args.iou_thres,
+    )
+
+    source = Path(args.source)
+
+    if source.is_file() and source.suffix in [".jpg", ".png"]:
+        image = cv2.imread(str(source))
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        output = pose.predict(image)
+        cv2.imwrite(
+            f"{str(source).rsplit('.', maxsplit=1)[0]}_out.jpg",
+            cv2.cvtColor(output, cv2.COLOR_RGB2BGR),
+        )
+
+    elif source.is_dir():
+        files = source.glob("*.jpg")
+        for file in files:
+            image = cv2.imread(str(file))
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            output = pose.predict(image)
+            cv2.imwrite(
+                f"{str(file).rsplit('.', maxsplit=1)[0]}_out.jpg",
+                cv2.cvtColor(output, cv2.COLOR_RGB2BGR),
+            )
+
+    elif source.is_file() and source.suffix in [".mp4", ".avi"]:
+        reader = VideoReader(args.source)
+        writer = VideoWriter(
+            f"{args.source.rsplit('.', maxsplit=1)[0]}_out.mp4", reader.fps
+        )
+        fps = FPS(len(reader.frames))
+
+        for frame in tqdm(reader):
+            fps.start()
+            output = pose.predict(frame.numpy())
+            fps.stop(False)
+            writer.update(output)
+
+        print(f"FPS: {fps.fps}")
+        writer.write()
+
+    else:
+        webcam = WebcamStream()
+        fps = FPS()
+
+        for frame in webcam:
+            fps.start()
+            output = pose.predict(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+            fps.stop()
+            cv2.imwrite("frame.jpg", cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pose/models/__init__.py b/pose/models/__init__.py
index 27f10a0..1cf1921 100644
--- a/pose/models/__init__.py
+++ b/pose/models/__init__.py
@@ -1,15 +1,15 @@
 from .posehrnet import PoseHRNet
 from .simdr import SimDR
 
+__all__ = ["PoseHRNet", "SimDR"]
 
-__all__ = ['PoseHRNet', 'SimDR']
-
-def get_pose_model(model_path: str):
-    if 'posehrnet' in model_path:
-        model = PoseHRNet('w32' if 'w32' in model_path else 'w48')
-    elif 'simdr' in model_path:
-        model = SimDR('w32' if 'w32' in model_path else 'w48')
+
+def get_pose_model(model_path: str) -> PoseHRNet | SimDR:
+    if "posehrnet" in model_path.lower():
+        model = PoseHRNet("w32" if "w32" in model_path.lower() else "w48")
+    elif "simdr" in model_path.lower():
+        model = SimDR("w32" if "w32" in model_path.lower() else "w48")
     else:
         raise NotImplementedError
-    return model
\ No newline at end of file
+
+    return model
diff --git a/pose/models/simdr.py b/pose/models/simdr.py
index 7bb80e8..f43494a 100644
--- a/pose/models/simdr.py
+++ b/pose/models/simdr.py
@@ -1,10 +1,16 @@
 import torch
-from torch import nn, Tensor
+from torch import Tensor, nn
+
 from .backbones import HRNet
 
 
 class SimDR(nn.Module):
-    def __init__(self, backbone: str = 'w32', num_joints: int = 17, image_size: tuple = (256, 192)):
+    def __init__(
+        self,
+        backbone: str = "w32",
+        num_joints: int = 17,
+        image_size: tuple = (256, 192),
+    ):
         super().__init__()
         self.backbone = HRNet(backbone)
         self.final_layer = nn.Conv2d(self.backbone.all_channels[0], num_joints, 1)
@@ -15,14 +21,16 @@ def __init__(self, backbone: str = 'w32', num_joints: int = 17, image_size: tupl
 
     def _init_weights(self, m: nn.Module) -> None:
         if isinstance(m, nn.Conv2d):
-            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
         elif isinstance(m, nn.BatchNorm2d):
             nn.init.constant_(m.weight, 1)
             nn.init.constant_(m.bias, 0)
 
     def init_pretrained(self, pretrained: str = None) -> None:
         if pretrained:
-            self.backbone.load_state_dict(torch.load(pretrained, map_location='cpu'), strict=False)
+            self.backbone.load_state_dict(
+                torch.load(pretrained, map_location="cpu"), strict=False
+            )
 
     def forward(self, x: Tensor) -> Tensor:
         out = self.backbone(x)
@@ -32,10 +40,13 @@ def forward(self, x: Tensor) -> Tensor:
         return pred_x, pred_y
 
 
-if __name__ == '__main__':
-    from torch.nn import functional as F
-    model = SimDR('w32')
-    model.load_state_dict(torch.load('checkpoints/pretrained/simdr_hrnet_w32_256x192.pth', map_location='cpu'))
+if __name__ == "__main__":
+    model = SimDR("w32")
+    model.load_state_dict(
+        torch.load(
+            "checkpoints/pretrained/simdr_hrnet_w32_256x192.pth", map_location="cpu"
+        )
+    )
     x = torch.randn(4, 3, 256, 192)
     px, py = model(x)
     print(px.shape, py.shape)
diff --git a/pose/utils/boxes.py b/pose/utils/boxes.py
index 0b208be..784d6b7 100644
--- a/pose/utils/boxes.py
+++ b/pose/utils/boxes.py
@@ -1,6 +1,6 @@
 import cv2
-import torch
 import numpy as np
+import torch
 from torchvision import ops
 
 
@@ -18,7 +18,9 @@ def letterbox(img, new_shape=(640, 640)):
     top, bottom = round(pH - 0.1), round(pH + 0.1)
     left, right = round(pW - 0.1), round(pW + 0.1)
 
-    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
+    img = cv2.copyMakeBorder(
+        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
+    )
     return img
 
 
@@ -31,73 +33,115 @@ def scale_boxes(boxes, orig_shape, new_shape):
     boxes[:, ::2] -= pad[1]
     boxes[:, 1::2] -= pad[0]
     boxes[:, :4] /= gain
-    
+
     boxes[:, ::2].clamp_(0, orig_shape[1])
     boxes[:, 1::2].clamp_(0, orig_shape[0])
     return boxes.round()
 
 
-def xywh2xyxy(x):
-    boxes = x.clone()
+def xywh2xyxy(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
+    if isinstance(x, torch.Tensor):
+        boxes = x.clone()
+    elif isinstance(x, np.ndarray):
+        boxes = x.copy()
+    else:
+        raise TypeError("Input must be a tensor or numpy array")
+
     boxes[:, 0] = x[:, 0] - x[:, 2] / 2
     boxes[:, 1] = x[:, 1] - x[:, 3] / 2
     boxes[:, 2] = x[:, 0] + x[:, 2] / 2
     boxes[:, 3] = x[:, 1] + x[:, 3] / 2
+
     return boxes
 
 
-def xyxy2xywh(x):
-    y = x.clone()
+def xyxy2xywh(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
+    if isinstance(x, torch.Tensor):
+        y = x.clone()
+    elif isinstance(x, np.ndarray):
+        y = x.copy()
+    else:
+        raise TypeError("Input must be a tensor or numpy array")
+
     y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
     y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
     y[:, 2] = x[:, 2] - x[:, 0]  # width
     y[:, 3] = x[:, 3] - x[:, 1]  # height
-    return y
-
-
-def non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=None):
-    candidates = pred[..., 4] > conf_thres
-
-    max_wh = 4096
-    max_nms = 30000
-    max_det = 300
-    output = [torch.zeros((0, 6), device=pred.device)] * pred.shape[0]
-
-    for xi, x in enumerate(pred):
-        x = x[candidates[xi]]
-
-        if not x.shape[0]: continue
-
-        # compute conf
-        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf
-
-        # box
-        box = xywh2xyxy(x[:, :4])
-
-        # detection matrix nx6
-        conf, j = x[:, 5:].max(1, keepdim=True)
-        x = torch.cat([box, conf, j.float()], dim=1)[conf.view(-1) > conf_thres]
-
-        # filter by class
-        if classes is not None:
-            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
-
-        # check shape
-        n = x.shape[0]
-        if not n:
-            continue
-        elif n > max_nms:
-            x = x[x[:, 4].argsort(descending=True)[:max_nms]]
-
-        # batched nms
-        c = x[:, 5:6] * max_wh
-        boxes, scores = x[:, :4] + c, x[:, 4]
-        keep = ops.nms(boxes, scores, iou_thres)
-
-        if keep.shape[0] > max_det:
-            keep = keep[:max_det]
+    return y
 
-        output[xi] = x[keep]
-    return output
\ No newline at end of file
+
+def non_max_suppression(
+    pred: torch.Tensor,
+    conf_thres: float = 0.25,
+    iou_thres: float = 0.45,
+    classes: list = None,
+    max_det: int = 300,
+) -> list:
+    """
+    Non-Maximum Suppression (NMS) on inference results
+
+    Args:
+        pred: predictions tensor (n, 7) [x, y, w, h, obj_conf, cls1_conf, cls2_conf]
+        conf_thres: confidence threshold
+        iou_thres: NMS IoU threshold
+        classes: filter by class (e.g. [0] for persons only)
+        max_det: maximum number of detections per image
+
+    Returns:
+        list of detections, one (n, 6) tensor per image [xyxy, conf, cls]
+    """
+    # Ensure pred is 2D
+    if pred.dim() == 1:
+        pred = pred.unsqueeze(0)
+
+    # Calculate confidence
+    conf = pred[:, 4]  # objectness score
+    class_scores = pred[:, 5:]  # class probabilities
+    class_conf, class_pred = class_scores.max(1)  # best class confidence and prediction
+    confidence = conf * class_conf  # combine scores
+
+    # Filter by confidence
+    conf_mask = confidence > conf_thres
+    pred = pred[conf_mask]
+    confidence = confidence[conf_mask]
+    class_pred = class_pred[conf_mask]
+
+    if not pred.shape[0]:  # no boxes
+        return [torch.zeros((0, 6), device=pred.device)]
+
+    # Convert boxes from [x, y, w, h] to [x1, y1, x2, y2]
+    boxes = xywh2xyxy(pred[:, :4])
+
+    # Filter by class
+    if classes is not None:
+        if isinstance(classes, int):
+            classes = [classes]
+        class_mask = torch.zeros_like(class_pred, dtype=torch.bool)
+        for c in classes:
+            class_mask |= class_pred == c
+        boxes = boxes[class_mask]
+        confidence = confidence[class_mask]
+        class_pred = class_pred[class_mask]
+
+    if not boxes.shape[0]:  # no boxes after filtering
+        return [torch.zeros((0, 6), device=pred.device)]
+
+    # Sort by confidence
+    sorted_indices = torch.argsort(confidence, descending=True)
+    boxes = boxes[sorted_indices]
+    confidence = confidence[sorted_indices]
+    class_pred = class_pred[sorted_indices]
+
+    # Apply NMS
+    keep = ops.nms(boxes, confidence, iou_thres)
+    if keep.shape[0] > max_det:
+        keep = keep[:max_det]
+
+    # Combine detections into final format [x1, y1, x2, y2, conf, cls]
+    output = torch.zeros((keep.shape[0], 6), device=pred.device)
+    output[:, :4] = boxes[keep]
+    output[:, 4] = confidence[keep]
+    output[:, 5] = class_pred[keep].float()
+
+    return [output]
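Not part of the diff: a minimal sanity check for the rewritten `non_max_suppression`, which now operates on a single image's raw predictions rather than a batch. The `[x, y, w, h, obj_conf, cls_conf]` rows below are made-up values, assuming a single "person" class:

```python
import torch

from pose.utils.boxes import non_max_suppression

# Two heavily overlapping detections plus one low-confidence one.
pred = torch.tensor([
    [100.0, 100.0, 50.0, 80.0, 0.9, 0.95],  # kept
    [102.0, 101.0, 52.0, 78.0, 0.8, 0.90],  # near-duplicate of the first, suppressed by NMS
    [300.0, 200.0, 40.0, 60.0, 0.1, 0.50],  # dropped by the 0.25 confidence threshold
])
(dets,) = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, classes=[0])
print(dets.shape)  # torch.Size([1, 6]) -> [x1, y1, x2, y2, conf, cls]
```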
diff --git a/pose/utils/decode.py b/pose/utils/decode.py
index d6fa353..ea26258 100644
--- a/pose/utils/decode.py
+++ b/pose/utils/decode.py
@@ -1,10 +1,13 @@
 import math
-import torch
+
 import numpy as np
+import torch
 from torch import Tensor
 
 
-def get_simdr_final_preds(pred_x: Tensor, pred_y: Tensor, boxes: Tensor, image_size: tuple):
+def get_simdr_final_preds(
+    pred_x: Tensor, pred_y: Tensor, boxes: Tensor, image_size: tuple
+):
     center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy()
     pred_x, pred_y = pred_x.softmax(dim=2), pred_y.softmax(dim=2)
 
@@ -29,11 +32,10 @@ def get_final_preds(heatmaps: Tensor, boxes: Tensor):
             py = int(math.floor(coords[n][p][1] + 0.5))
 
             if 1 < px < W - 1 and 1 < py < H - 1:
-                diff = np.array([
-                    hm[py][px+1] - hm[py][px-1],
-                    hm[py+1][px] - hm[py-1][px]
-                ])
-                coords[n][p] += np.sign(diff) * .25
+                diff = np.array(
+                    [hm[py][px + 1] - hm[py][px - 1], hm[py + 1][px] - hm[py - 1][px]]
+                )
+                coords[n][p] += np.sign(diff) * 0.25
 
     for i in range(B):
         coords[i] = transform_preds(coords[i], center[i], scale[i], [W, H])
@@ -59,4 +61,4 @@ def transform_preds(coords, center, scale, output_size):
     target_coords = np.ones_like(coords)
     target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
     target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
-    return target_coords
\ No newline at end of file
+    return target_coords
diff --git a/pose/utils/utils.py b/pose/utils/utils.py
index f14277c..6639a25 100644
--- a/pose/utils/utils.py
+++ b/pose/utils/utils.py
@@ -13,28 +13,71 @@ def setup_cudnn() -> None:
     cudnn.deterministic = False
 
 
+def draw_bbox(
+    img: np.ndarray,
+    boxes: np.ndarray,
+    color: tuple[int, int, int] = (0, 255, 0),
+    thickness: int = 2,
+    font_scale: float = 0.5,
+    text_color: tuple[int, int, int] = (255, 255, 255),
+    text_thickness: int = 1,
+) -> np.ndarray:
+    for box in boxes:
+        x1, y1, x2, y2 = map(int, box[:4])
+        cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
+        conf_text = f"{box[4]:.2f}"  # keep confidence a float; casting to int would truncate the score
+        (text_width, text_height), _ = cv2.getTextSize(
+            conf_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_thickness
+        )
+        cv2.rectangle(
+            img,
+            (x1, y1 - text_height - 5),
+            (x1 + text_width + 5, y1),
+            color,
+            -1,  # Filled rectangle
+        )
+        cv2.putText(
+            img,
+            conf_text,
+            (x1 + 3, y1 - 4),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            font_scale,
+            text_color,
+            text_thickness,
+        )
+
+    return img
+
+
 def draw_coco_keypoints(img, keypoints, skeletons):
-    if keypoints == []: return img
+    if keypoints == []:
+        return img
     image = img.copy()
     for kpts in keypoints:
         for x, y, v in kpts:
             if v == 2:
                 cv2.circle(image, (x, y), 4, (255, 0, 0), 2)
         for kid1, kid2 in skeletons:
-            x1, y1, v1 = kpts[kid1-1]
-            x2, y2, v2 = kpts[kid2-1]
+            x1, y1, v1 = kpts[kid1 - 1]
+            x2, y2, v2 = kpts[kid2 - 1]
             if v1 == 2 and v2 == 2:
-                cv2.line(image, (x1, y1), (x2, y2), (0, 255, 0), 2) 
-    return image
+                cv2.line(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
+
+    return image
 
 
 def draw_keypoints(img, keypoints, skeletons):
-    if keypoints == []: return img
+    if len(keypoints) == 0 or (
+        isinstance(keypoints, np.ndarray) and keypoints.size == 0
+    ):
+        return img
+
     for kpts in keypoints:
         for x, y in kpts:
             cv2.circle(img, (x, y), 4, (255, 0, 0), 2, cv2.LINE_AA)
         for kid1, kid2 in skeletons:
-            cv2.line(img, kpts[kid1-1], kpts[kid2-1], (0, 255, 0), 2, cv2.LINE_AA)
+            cv2.line(img, kpts[kid1 - 1], kpts[kid2 - 1], (0, 255, 0), 2, cv2.LINE_AA)
+
+    return img
 
 
 class WebcamStream:
@@ -56,7 +99,7 @@ def __iter__(self):
 
     def __next__(self):
         self.count += 1
-        if cv2.waitKey(1) == ord('q'):
+        if cv2.waitKey(1) == ord("q"):
             self.stop()
         return self.frame.copy()
 
@@ -71,8 +114,8 @@ def __len__(self):
 
 class VideoReader:
     def __init__(self, video: str):
-        self.frames, _, info = io.read_video(video, pts_unit='sec')
-        self.fps = info['video_fps']
+        self.frames, _, info = io.read_video(video, pts_unit="sec")
+        self.fps = info["video_fps"]
 
         print(f"Processing '{video}'...")
         print(f"Total Frames: {len(self.frames)}")
@@ -130,7 +173,8 @@ def stop(self, debug=True):
         self.counts += 1
         if self.counts == self.avg:
             self.fps = round(self.counts / self.accum_time)
-            if debug: print(f"FPS: {self.fps}")
+            if debug:
+                print(f"FPS: {self.fps}")
             self.counts = 0
             self.accum_time = 0
 
@@ -168,4 +212,4 @@ def get_affine_transform(center, scale, patch_size, rot=0, inv=False):
     src[2:, :] = get_3rd_point(src[0, :], src[1, :])
     dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
 
-    return cv2.getAffineTransform(dst, src) if inv else cv2.getAffineTransform(src, dst)
\ No newline at end of file
+    return cv2.getAffineTransform(dst, src) if inv else cv2.getAffineTransform(src, dst)
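Also not part of the diff: the new `draw_bbox` expects the `(n, 6)` `[x1, y1, x2, y2, conf, cls]` array that `postprocess` hands it. A quick sketch on a dummy canvas; keeping `conf` a float is what makes the label render as "0.87" rather than "0":

```python
import numpy as np

from pose.utils.utils import draw_bbox

canvas = np.zeros((480, 640, 3), dtype=np.uint8)  # blank RGB frame
dets = np.array([[50.0, 60.0, 200.0, 300.0, 0.87, 0.0]])  # one fake person detection
canvas = draw_bbox(canvas, dets)  # green box with a "0.87" label above it
```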
+packages = ["pose"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7b32f69..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -opencv-python -numpy -tqdm diff --git a/yolov5 b/yolov5 deleted file mode 160000 index aa18599..0000000 --- a/yolov5 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit aa1859909c96d5e1fc839b2746b45038ee8465c9