#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Main wrapper class for Rex-Omni
"""
import json
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize

from .parser import convert_boxes_to_normalized_bins, parse_prediction
from .tasks import TASK_CONFIGS, TaskType, get_keypoint_config, get_task_config


class RexOmniWrapper:
    """
    High-level wrapper for Rex-Omni
    """

    def __init__(
        self,
        model_path: str,
        backend: str = "transformers",
        system_prompt: str = "You are a helpful assistant",
        min_pixels: int = 16 * 28 * 28,
        max_pixels: int = 2560 * 28 * 28,
        max_tokens: int = 4096,
        temperature: float = 0.0,
        top_p: float = 0.8,
        top_k: int = 1,
        repetition_penalty: float = 1.05,
        skip_special_tokens: bool = False,
        stop: Optional[List[str]] = None,
        **kwargs,
    ):
| """ | |
| Initialize the wrapper | |
| Args: | |
| model_path: Path to the model directory | |
| backend: Backend type ("transformers" or "vllm") | |
| system_prompt: System prompt for the model | |
| min_pixels: Minimum pixels for image processing | |
| max_pixels: Maximum pixels for image processing | |
| max_tokens: Maximum number of tokens to generate | |
| temperature: Controls randomness in generation | |
| top_p: Nucleus sampling parameter | |
| top_k: Top-k sampling parameter | |
| repetition_penalty: Penalty for repetition | |
| skip_special_tokens: Whether to skip special tokens in output | |
| stop: Stop sequences for generation | |
| **kwargs: Additional arguments for model initialization | |
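
        Example:
            # Illustrative sketch only; the checkpoint path is a placeholder for
            # a locally downloaded Rex-Omni checkpoint.
            model = RexOmniWrapper(
                model_path="path/to/rex-omni-checkpoint",
                backend="transformers",
            )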
| """ | |
| self.model_path = model_path | |
| self.backend = backend.lower() | |
| self.system_prompt = system_prompt | |
| self.min_pixels = min_pixels | |
| self.max_pixels = max_pixels | |
| # Store generation parameters | |
| self.max_tokens = max_tokens | |
| self.temperature = temperature | |
| self.top_p = top_p | |
| self.top_k = top_k | |
| self.repetition_penalty = repetition_penalty | |
| self.skip_special_tokens = skip_special_tokens | |
| self.stop = stop or ["<|im_end|>"] | |
| # Initialize model and processor | |
| self._initialize_model(**kwargs) | |

    def _initialize_model(self, **kwargs):
        """Initialize model and processor based on backend type"""
        print(f"Initializing {self.backend} backend...")

        if self.backend == "vllm":
            from transformers import AutoProcessor
            from vllm import LLM, SamplingParams

            # Initialize VLLM model
            self.model = LLM(
                model=self.model_path,
                tokenizer=self.model_path,
                tokenizer_mode=kwargs.get("tokenizer_mode", "slow"),
                limit_mm_per_prompt=kwargs.get(
                    "limit_mm_per_prompt", {"image": 10, "video": 10}
                ),
                max_model_len=kwargs.get("max_model_len", 4096),
                gpu_memory_utilization=kwargs.get("gpu_memory_utilization", 0.8),
                tensor_parallel_size=kwargs.get("tensor_parallel_size", 1),
                trust_remote_code=kwargs.get("trust_remote_code", True),
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k
                    not in [
                        "tokenizer_mode",
                        "limit_mm_per_prompt",
                        "max_model_len",
                        "gpu_memory_utilization",
                        "tensor_parallel_size",
                        "trust_remote_code",
                    ]
                },
            )

            # Initialize processor
            self.processor = AutoProcessor.from_pretrained(
                self.model_path,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )
            # Set padding side to left for batch inference with Flash Attention
            self.processor.tokenizer.padding_side = "left"

            # Set up sampling parameters
            self.sampling_params = SamplingParams(
                max_tokens=self.max_tokens,
                top_p=self.top_p,
                repetition_penalty=self.repetition_penalty,
                top_k=self.top_k,
                temperature=self.temperature,
                skip_special_tokens=self.skip_special_tokens,
                stop=self.stop,
            )
            self.model_type = "vllm"

        elif self.backend == "transformers":
            import torch
            from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

            # Initialize transformers model
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=kwargs.get("torch_dtype", torch.bfloat16),
                attn_implementation=kwargs.get(
                    "attn_implementation", "flash_attention_2"
                ),
                device_map=kwargs.get("device_map", "auto"),
                trust_remote_code=kwargs.get("trust_remote_code", True),
                **{
                    k: v
                    for k, v in kwargs.items()
                    if k
                    not in [
                        "torch_dtype",
                        "attn_implementation",
                        "device_map",
                        "trust_remote_code",
                    ]
                },
            )

            # Initialize processor
            self.processor = AutoProcessor.from_pretrained(
                self.model_path,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
                use_fast=False,
            )
            # Set padding side to left for batch inference with Flash Attention
            self.processor.tokenizer.padding_side = "left"
            self.model_type = "transformers"

        else:
            raise ValueError(
                f"Unsupported backend: {self.backend}. Choose 'transformers' or 'vllm'."
            )
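
    # Backend-specific keyword arguments are forwarded through __init__'s **kwargs,
    # e.g. (illustrative): tensor_parallel_size / gpu_memory_utilization for the
    # vLLM backend, or device_map / torch_dtype for the transformers backend.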

    def inference(
        self,
        images: Union[Image.Image, List[Image.Image]],
        task: Union[str, TaskType, List[Union[str, TaskType]]],
        categories: Optional[Union[str, List[str], List[List[str]]]] = None,
        keypoint_type: Optional[Union[str, List[str]]] = None,
        visual_prompt_boxes: Optional[
            Union[List[List[float]], List[List[List[float]]]]
        ] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Perform batch inference on images for various vision tasks.

        Args:
            images: Input image(s) in PIL.Image format. Can be a single image or a list of images.
            task: Task type(s). Can be a single task or a list of tasks for batch processing.
                Available options:
                - "detection": Object detection with bounding boxes
                - "pointing": Point to objects with coordinates
                - "visual_prompting": Find similar objects based on reference boxes
                - "keypoint": Detect keypoints for persons/hands/animals
                - "ocr_box": Detect and recognize text in bounding boxes
                - "ocr_polygon": Detect and recognize text in polygons
                - "gui_grounding": Locate GUI elements and return bounding boxes
                - "gui_pointing": Locate GUI elements and return points
            categories: Object categories to detect/locate. Can be:
                - Single string: "person"
                - List of strings: ["person", "car"] (applied to all images)
                - List of lists: [["person"], ["car", "dog"]] (per-image categories)
            keypoint_type: Type of keypoints for the keypoint detection task.
                Can be a single string or a list of strings for batch processing.
                Options: "person", "hand", "animal"
            visual_prompt_boxes: Reference bounding boxes for the visual prompting task.
                Can be a single list or a list of lists for batch processing.
                Format: [[x0, y0, x1, y1], ...] or [[[x0, y0, x1, y1], ...], ...]
            **kwargs: Additional arguments (reserved for future use)

        Returns:
            List of prediction dictionaries, one for each input image. Each dictionary contains:
                - success (bool): Whether inference succeeded
                - extracted_predictions (dict): Parsed predictions by category
                - raw_output (str): Raw model output text
                - inference_time (float): Total inference time in seconds
                - num_output_tokens (int): Number of generated tokens
                - num_prompt_tokens (int): Number of input tokens
                - tokens_per_second (float): Generation speed
                - image_size (tuple): Input image dimensions (width, height)
                - resized_size (tuple): Image dimensions after smart resize (width, height)
                - task (str): Task type used
                - prompt (str): Generated prompt sent to the model

        Examples:
            # Single image object detection
            results = model.inference(
                images=image,
                task="detection",
                categories=["person", "car", "dog"]
            )

            # Batch processing with same task and categories
            results = model.inference(
                images=[img1, img2, img3],
                task="detection",
                categories=["person", "car"]
            )

            # Batch processing with different tasks per image
            results = model.inference(
                images=[img1, img2, img3],
                task=["detection", "pointing", "keypoint"],
                categories=[["person", "car"], ["dog"], ["person"]],
                keypoint_type=[None, None, "person"]
            )

            # Batch keypoint detection with different types
            results = model.inference(
                images=[img1, img2],
                task=["keypoint", "keypoint"],
                categories=[["person"], ["hand"]],
                keypoint_type=["person", "hand"]
            )

            # Batch visual prompting
            results = model.inference(
                images=[img1, img2],
                task="visual_prompting",
                visual_prompt_boxes=[
                    [[100, 100, 200, 200]],
                    [[50, 50, 150, 150], [300, 300, 400, 400]]
                ]
            )

            # Mixed batch processing
            results = model.inference(
                images=[img1, img2, img3],
                task=["detection", "ocr_box", "pointing"],
                categories=[["person", "car"], ["text"], ["dog"]]
            )
        """
        # Convert single image to list
        if isinstance(images, Image.Image):
            images = [images]
        batch_size = len(images)

        # Normalize inputs to batch format
        tasks, categories_list, keypoint_types, visual_prompt_boxes_list = (
            self._normalize_batch_inputs(
                task, categories, keypoint_type, visual_prompt_boxes, batch_size
            )
        )

        # Perform batch inference
        return self._inference_batch(
            images=images,
            tasks=tasks,
            categories_list=categories_list,
            keypoint_types=keypoint_types,
            visual_prompt_boxes_list=visual_prompt_boxes_list,
            **kwargs,
        )

    def _normalize_batch_inputs(
        self,
        task: Union[str, TaskType, List[Union[str, TaskType]]],
        categories: Optional[Union[str, List[str], List[List[str]]]],
        keypoint_type: Optional[Union[str, List[str]]],
        visual_prompt_boxes: Optional[
            Union[List[List[float]], List[List[List[float]]]]
        ],
        batch_size: int,
    ) -> Tuple[
        List[TaskType],
        List[Optional[List[str]]],
        List[Optional[str]],
        List[Optional[List[List[float]]]],
    ]:
        """Normalize all inputs to batch format"""
        # Normalize tasks
        if isinstance(task, (str, TaskType)):
            # Single task for all images
            if isinstance(task, str):
                task = TaskType(task.lower())
            tasks = [task] * batch_size
        else:
            # List of tasks
            tasks = []
            for t in task:
                if isinstance(t, str):
                    tasks.append(TaskType(t.lower()))
                else:
                    tasks.append(t)
            if len(tasks) != batch_size:
                raise ValueError(
                    f"Number of tasks ({len(tasks)}) must match number of images ({batch_size})"
                )

        # Normalize categories
        if categories is None:
            categories_list = [None] * batch_size
        elif isinstance(categories, str):
            # Single string for all images
            categories_list = [[categories]] * batch_size
        elif isinstance(categories, list):
            if len(categories) == 0:
                categories_list = [None] * batch_size
            elif isinstance(categories[0], str):
                # List of strings for all images
                categories_list = [categories] * batch_size
            else:
                # List of lists (per-image categories)
                categories_list = categories
                if len(categories_list) != batch_size:
                    raise ValueError(
                        f"Number of category lists ({len(categories_list)}) must match number of images ({batch_size})"
                    )
        else:
            categories_list = [None] * batch_size

        # Normalize keypoint_type
        if keypoint_type is None:
            keypoint_types = [None] * batch_size
        elif isinstance(keypoint_type, str):
            # Single keypoint type for all images
            keypoint_types = [keypoint_type] * batch_size
        else:
            # List of keypoint types
            keypoint_types = keypoint_type
            if len(keypoint_types) != batch_size:
                raise ValueError(
                    f"Number of keypoint types ({len(keypoint_types)}) must match number of images ({batch_size})"
                )

        # Normalize visual_prompt_boxes
        if visual_prompt_boxes is None:
            visual_prompt_boxes_list = [None] * batch_size
        elif isinstance(visual_prompt_boxes, list):
            if len(visual_prompt_boxes) == 0:
                visual_prompt_boxes_list = [None] * batch_size
            elif isinstance(visual_prompt_boxes[0], (int, float)):
                # Single box for all images: [x0, y0, x1, y1]
                visual_prompt_boxes_list = [[visual_prompt_boxes]] * batch_size
            elif isinstance(visual_prompt_boxes[0], list):
                if len(visual_prompt_boxes[0]) == 4 and isinstance(
                    visual_prompt_boxes[0][0], (int, float)
                ):
                    # List of boxes for all images: [[x0, y0, x1, y1], ...]
                    visual_prompt_boxes_list = [visual_prompt_boxes] * batch_size
                else:
                    # List of lists of boxes (per-image boxes): [[[x0, y0, x1, y1], ...], ...]
                    visual_prompt_boxes_list = visual_prompt_boxes
                    if len(visual_prompt_boxes_list) != batch_size:
                        raise ValueError(
                            f"Number of visual prompt box lists ({len(visual_prompt_boxes_list)}) must match number of images ({batch_size})"
                        )
            else:
                visual_prompt_boxes_list = [None] * batch_size
        else:
            visual_prompt_boxes_list = [None] * batch_size

        return tasks, categories_list, keypoint_types, visual_prompt_boxes_list

    def _inference_batch(
        self,
        images: List[Image.Image],
        tasks: List[TaskType],
        categories_list: List[Optional[List[str]]],
        keypoint_types: List[Optional[str]],
        visual_prompt_boxes_list: List[Optional[List[List[float]]]],
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Perform true batch inference"""
        start_time = time.time()
        batch_size = len(images)

        # Prepare batch data
        batch_messages = []
        batch_prompts = []
        batch_image_sizes = []

        for i in range(batch_size):
            image = images[i]
            task = tasks[i]
            categories = categories_list[i]
            keypoint_type = keypoint_types[i]
            visual_prompt_boxes = visual_prompt_boxes_list[i]

            # Get image dimensions
            w, h = image.size
            batch_image_sizes.append((w, h))

            # Generate prompt
            prompt = self._generate_prompt(
                task=task,
                categories=categories,
                keypoint_type=keypoint_type,
                visual_prompt_boxes=visual_prompt_boxes,
                image_width=w,
                image_height=h,
            )
            batch_prompts.append(prompt)

            # Calculate resized dimensions
            resized_height, resized_width = smart_resize(
                h,
                w,
                28,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )

            # Prepare messages
            if self.model_type == "transformers":
                messages = [
                    {"role": "system", "content": self.system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "image": image,
                                "resized_height": resized_height,
                                "resized_width": resized_width,
                            },
                            {"type": "text", "text": prompt},
                        ],
                    },
                ]
            else:
                messages = [
                    {"role": "system", "content": self.system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "image": image,
                                "min_pixels": self.min_pixels,
                                "max_pixels": self.max_pixels,
                            },
                            {"type": "text", "text": prompt},
                        ],
                    },
                ]
            batch_messages.append(messages)

        # Perform batch generation
        if self.model_type == "vllm":
            batch_outputs, batch_generation_info = self._generate_vllm_batch(
                batch_messages
            )
        else:
            batch_outputs, batch_generation_info = self._generate_transformers_batch(
                batch_messages, images
            )

        # Parse results
        results = []
        total_time = time.time() - start_time

        for i in range(batch_size):
            raw_output = batch_outputs[i]
            generation_info = batch_generation_info[i]
            w, h = batch_image_sizes[i]
            task = tasks[i]
            prompt = batch_prompts[i]

            # Parse predictions
            extracted_predictions = parse_prediction(
                text=raw_output,
                w=w,
                h=h,
                task_type=task.value,
            )

            # Calculate resized dimensions for result
            resized_height, resized_width = smart_resize(
                h,
                w,
                28,
                min_pixels=self.min_pixels,
                max_pixels=self.max_pixels,
            )

            result = {
                "success": True,
                "image_size": (w, h),
                "resized_size": (resized_width, resized_height),
                "task": task.value,
                "prompt": prompt,
                "raw_output": raw_output,
                "extracted_predictions": extracted_predictions,
                "inference_time": total_time,  # Total batch time
                **generation_info,
            }
            results.append(result)

        return results

    def _inference_single(
        self,
        image: Image.Image,
        task: TaskType,
        categories: Optional[Union[str, List[str]]] = None,
        keypoint_type: Optional[str] = None,
        visual_prompt_boxes: Optional[List[List[float]]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Perform inference on a single image"""
        start_time = time.time()

        # Get image dimensions
        w, h = image.size

        # Generate prompt based on task
        final_prompt = self._generate_prompt(
            task=task,
            categories=categories,
            keypoint_type=keypoint_type,
            visual_prompt_boxes=visual_prompt_boxes,
            image_width=w,
            image_height=h,
        )

        # Calculate resized dimensions using smart_resize
        resized_height, resized_width = smart_resize(
            h,
            w,
            28,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )

        # Prepare messages
        if self.model_type == "transformers":
            # For transformers, use resized_height and resized_width
            messages = [
                {"role": "system", "content": self.system_prompt},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": resized_height,
                            "resized_width": resized_width,
                        },
                        {"type": "text", "text": final_prompt},
                    ],
                },
            ]
        else:
            # For VLLM, use min_pixels and max_pixels
            messages = [
                {"role": "system", "content": self.system_prompt},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "min_pixels": self.min_pixels,
                            "max_pixels": self.max_pixels,
                        },
                        {"type": "text", "text": final_prompt},
                    ],
                },
            ]

        # Generate response
        if self.model_type == "vllm":
            raw_output, generation_info = self._generate_vllm(messages)
        else:
            raw_output, generation_info = self._generate_transformers(messages)

        # Parse predictions
        extracted_predictions = parse_prediction(
            text=raw_output,
            w=w,
            h=h,
            task_type=task.value,
        )

        # Calculate timing
        total_time = time.time() - start_time

        return {
            "success": True,
            "image_size": (w, h),
            "resized_size": (resized_width, resized_height),
            "task": task.value,
            "prompt": final_prompt,
            "raw_output": raw_output,
            "extracted_predictions": extracted_predictions,
            "inference_time": total_time,
            **generation_info,
        }

    def _generate_prompt(
        self,
        task: TaskType,
        categories: Optional[Union[str, List[str]]] = None,
        keypoint_type: Optional[str] = None,
        visual_prompt_boxes: Optional[List[List[float]]] = None,
        image_width: Optional[int] = None,
        image_height: Optional[int] = None,
    ) -> str:
        """Generate prompt based on task configuration"""
        task_config = get_task_config(task)

        if task == TaskType.VISUAL_PROMPTING:
            if visual_prompt_boxes is None:
                raise ValueError(
                    "Visual prompt boxes are required for visual prompting task"
                )
            # Convert boxes to normalized bins format
            word_mapped_boxes = convert_boxes_to_normalized_bins(
                visual_prompt_boxes, image_width, image_height
            )
            visual_prompt_dict = {"object_1": word_mapped_boxes}
            visual_prompt_json = json.dumps(visual_prompt_dict)
            return task_config.prompt_template.format(visual_prompt=visual_prompt_json)

        elif task == TaskType.KEYPOINT:
            if categories is None:
                raise ValueError("Categories are required for keypoint task")
            if keypoint_type is None:
                raise ValueError("Keypoint type is required for keypoint task")

            keypoints_list = get_keypoint_config(keypoint_type)
            if keypoints_list is None:
                raise ValueError(f"Unknown keypoint type: {keypoint_type}")

            keypoints_str = ", ".join(keypoints_list)
            categories_str = (
                ", ".join(categories) if isinstance(categories, list) else categories
            )
            return task_config.prompt_template.format(
                categories=categories_str, keypoints=keypoints_str
            )

        else:
            # Standard tasks (detection, pointing, OCR, etc.)
            if task_config.requires_categories and categories is None:
                raise ValueError(f"Categories are required for {task.value} task")

            if categories is not None:
                categories_str = (
                    ", ".join(categories)
                    if isinstance(categories, list)
                    else categories
                )
                return task_config.prompt_template.format(categories=categories_str)
            else:
                return task_config.prompt_template.format(categories="objects")

    def _generate_vllm(self, messages: List[Dict]) -> Tuple[str, Dict]:
        """Generate using VLLM model"""
        # Process vision info
        image_inputs, video_inputs = process_vision_info(messages)
        mm_data = {"image": image_inputs}

        prompt = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        llm_inputs = {
            "prompt": prompt,
            "multi_modal_data": mm_data,
        }

        # Generate
        generation_start = time.time()
        outputs = self.model.generate(
            [llm_inputs], sampling_params=self.sampling_params
        )
        generation_time = time.time() - generation_start

        generated_text = outputs[0].outputs[0].text

        # Extract token information
        output_tokens = outputs[0].outputs[0].token_ids
        num_output_tokens = len(output_tokens) if output_tokens else 0
        prompt_token_ids = outputs[0].prompt_token_ids
        num_prompt_tokens = len(prompt_token_ids) if prompt_token_ids else 0
        tokens_per_second = (
            num_output_tokens / generation_time if generation_time > 0 else 0
        )

        return generated_text, {
            "num_output_tokens": num_output_tokens,
            "num_prompt_tokens": num_prompt_tokens,
            "generation_time": generation_time,
            "tokens_per_second": tokens_per_second,
        }

    def _generate_vllm_batch(
        self, batch_messages: List[List[Dict]]
    ) -> Tuple[List[str], List[Dict]]:
        """Generate using VLLM model for batch processing"""
        # Process all messages
        batch_inputs = []
        for messages in batch_messages:
            # Process vision info
            image_inputs, video_inputs = process_vision_info(messages)
            mm_data = {"image": image_inputs}

            prompt = self.processor.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            llm_inputs = {
                "prompt": prompt,
                "multi_modal_data": mm_data,
            }
            batch_inputs.append(llm_inputs)

        # Generate for entire batch
        generation_start = time.time()
        outputs = self.model.generate(
            batch_inputs, sampling_params=self.sampling_params
        )
        generation_time = time.time() - generation_start

        # Extract results
        batch_outputs = []
        batch_generation_info = []
        for output in outputs:
            generated_text = output.outputs[0].text
            batch_outputs.append(generated_text)

            # Extract token information
            output_tokens = output.outputs[0].token_ids
            num_output_tokens = len(output_tokens) if output_tokens else 0
            prompt_token_ids = output.prompt_token_ids
            num_prompt_tokens = len(prompt_token_ids) if prompt_token_ids else 0
            tokens_per_second = (
                num_output_tokens / generation_time if generation_time > 0 else 0
            )
            generation_info = {
                "num_output_tokens": num_output_tokens,
                "num_prompt_tokens": num_prompt_tokens,
                "generation_time": generation_time,
                "tokens_per_second": tokens_per_second,
            }
            batch_generation_info.append(generation_info)

        return batch_outputs, batch_generation_info

    def _generate_transformers(self, messages: List[Dict]) -> Tuple[str, Dict]:
        """Generate using Transformers model"""
        # Apply chat template
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process inputs
        generation_start = time.time()
        inputs = self.processor(
            text=[text],
            images=[messages[1]["content"][0]["image"]],
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Prepare generation kwargs
        generation_kwargs = {
            "max_new_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "do_sample": self.temperature > 0,  # Enable sampling if temperature > 0
            "pad_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Generate
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, **generation_kwargs)
        generation_time = time.time() - generation_start

        # Decode
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=self.skip_special_tokens,
            clean_up_tokenization_spaces=False,
        )[0]

        num_output_tokens = len(generated_ids_trimmed[0])
        num_prompt_tokens = len(inputs.input_ids[0])
        tokens_per_second = (
            num_output_tokens / generation_time if generation_time > 0 else 0
        )

        return output_text, {
            "num_output_tokens": num_output_tokens,
            "num_prompt_tokens": num_prompt_tokens,
            "generation_time": generation_time,
            "tokens_per_second": tokens_per_second,
        }

    def _generate_transformers_batch(
        self, batch_messages: List[List[Dict]], batch_images: List[Image.Image]
    ) -> Tuple[List[str], List[Dict]]:
        """Generate using Transformers model for batch processing"""
        # Prepare batch inputs
        batch_texts = []
        for messages in batch_messages:
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            batch_texts.append(text)

        # Process inputs for batch
        generation_start = time.time()
        inputs = self.processor(
            text=batch_texts,
            images=batch_images,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Prepare generation kwargs
        generation_kwargs = {
            "max_new_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "do_sample": self.temperature > 0,
            "pad_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Generate for entire batch
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, **generation_kwargs)
        generation_time = time.time() - generation_start

        # Decode batch results
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        batch_outputs = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=self.skip_special_tokens,
            clean_up_tokenization_spaces=False,
        )

        # Prepare generation info for each item
        batch_generation_info = []
        for i, output_ids in enumerate(generated_ids_trimmed):
            num_output_tokens = len(output_ids)
            num_prompt_tokens = len(inputs.input_ids[i])
            tokens_per_second = (
                num_output_tokens / generation_time if generation_time > 0 else 0
            )
            generation_info = {
                "num_output_tokens": num_output_tokens,
                "num_prompt_tokens": num_prompt_tokens,
                "generation_time": generation_time,
                "tokens_per_second": tokens_per_second,
            }
            batch_generation_info.append(generation_info)

        return batch_outputs, batch_generation_info

    def get_supported_tasks(self) -> List[str]:
        """Get list of supported tasks"""
        return [task.value for task in TaskType]

    def get_task_info(self, task: Union[str, TaskType]) -> Dict[str, Any]:
        """Get information about a specific task"""
        if isinstance(task, str):
            task = TaskType(task.lower())

        config = get_task_config(task)
        return {
            "name": config.name,
            "description": config.description,
            "output_format": config.output_format,
            "requires_categories": config.requires_categories,
            "requires_visual_prompt": config.requires_visual_prompt,
            "requires_keypoint_type": config.requires_keypoint_type,
            "prompt_template": config.prompt_template,
        }
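

# Minimal usage sketch (illustrative only, not executed on import). The
# checkpoint and image paths below are placeholders; point them at a real
# Rex-Omni checkpoint and a local image before running.
if __name__ == "__main__":
    model = RexOmniWrapper(
        model_path="path/to/rex-omni-checkpoint",  # placeholder path
        backend="transformers",
    )
    demo_image = Image.open("path/to/example.jpg")  # placeholder image
    results = model.inference(
        images=demo_image,
        task="detection",
        categories=["person", "car"],
    )
    print(results[0]["extracted_predictions"])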