import torch
import cv2
import supervision as sv
import numpy as np
from rt_pose import PoseEstimationPipeline, PoseEstimationOutput
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class VitPose:
    def __init__(self):
        self.pipeline = PoseEstimationPipeline(
            object_detection_checkpoint="PekingU/rtdetr_r50vd_coco_o365",
            pose_estimation_checkpoint="usyd-community/vitpose-plus-small",
            device="cuda" if torch.cuda.is_available() else "cpu",
            dtype=torch.bfloat16,
            compile=True,  # torch.compile the models for extra speedup
        )
        self.output_video_path = None
        self.video_metadata = {}
    def video_to_frames(self, video):
        """Read a video file and return its frames, recording fps and dimensions."""
        frames = []
        cap = cv2.VideoCapture(video)
        self.video_metadata = {
            "fps": cap.get(cv2.CAP_PROP_FPS),
            "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        }
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
        return frames
    def run(self, video):
        frames = self.video_to_frames(video)
        annotated_frames = []
        for i, frame in enumerate(frames):
            logger.info(f"Processing frame {i + 1} of {len(frames)}")
            output = self.pipeline(frame)
            annotated_frame = self.visualize_output(frame, output)
            annotated_frames.append(annotated_frame)
        logger.info(f"Processed {len(annotated_frames)} frames")
        return annotated_frames
    def visualize_output(self, image: np.ndarray, output: PoseEstimationOutput, confidence: float = 0.3) -> np.ndarray:
        """
        Visualize pose estimation output.
        """
        keypoints_xy = output.keypoints_xy.float().cpu().numpy()
        scores = output.scores.float().cpu().numpy()

        # Supervision will not draw vertices with `0` score
        # and coordinates with `(0, 0)` value
        invisible_keypoints = scores < confidence
        scores[invisible_keypoints] = 0
        keypoints_xy[invisible_keypoints] = 0

        keypoints = sv.KeyPoints(xy=keypoints_xy, confidence=scores)

        # Scale marker size with the mean height of the detected person boxes
        _, y_min, _, y_max = output.person_boxes_xyxy.T
        height = int((y_max - y_min).mean().item())
        radius = max(height // 100, 4)
        thickness = max(height // 200, 2)

        edge_annotator = sv.EdgeAnnotator(color=sv.Color.YELLOW, thickness=thickness)
        vertex_annotator = sv.VertexAnnotator(color=sv.Color.ROBOFLOW, radius=radius)

        annotated_frame = image.copy()
        annotated_frame = edge_annotator.annotate(annotated_frame, keypoints)
        annotated_frame = vertex_annotator.annotate(annotated_frame, keypoints)
        return annotated_frame
    def frames_to_video(self, frames):
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        height = self.video_metadata["height"]
        width = self.video_metadata["width"]

        # Always ensure vertical orientation: rotate only if the video is in landscape mode
        rotate = width > height

        # The VideoWriter needs the dimensions of the output frames
        if rotate:
            logger.info(f"Original dimensions: {width}x{height}, rotated dimensions: {height}x{width}")
            out = cv2.VideoWriter(self.output_video_path, fourcc, self.video_metadata["fps"], (height, width))
        else:
            logger.info(f"Dimensions: {width}x{height}")
            out = cv2.VideoWriter(self.output_video_path, fourcc, self.video_metadata["fps"], (width, height))

        for frame in frames:
            if rotate:
                # Rotate landscape frames 90 degrees to make them vertical
                out.write(cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE))
            else:
                # Already vertical, no rotation needed
                out.write(frame)
        out.release()
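
# A minimal usage sketch (an assumption, not part of the original file):
# "input.mp4" and "output.mp4" are placeholder paths, and output_video_path
# is assumed to be set by the caller before frames_to_video() is invoked.
if __name__ == "__main__":
    vitpose = VitPose()
    vitpose.output_video_path = "output.mp4"
    annotated_frames = vitpose.run("input.mp4")
    vitpose.frames_to_video(annotated_frames)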