#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Main wrapper class for Rex-Omni.
"""
import json
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from .parser import convert_boxes_to_normalized_bins, parse_prediction
from .tasks import TASK_CONFIGS, TaskType, get_keypoint_config, get_task_config
class RexOmniWrapper:
"""
High-level wrapper for Rex-Omni
"""
def __init__(
self,
model_path: str,
backend: str = "transformers",
system_prompt: str = "You are a helpful assistant",
min_pixels: int = 16 * 28 * 28,
max_pixels: int = 2560 * 28 * 28,
max_tokens: int = 4096,
temperature: float = 0.0,
top_p: float = 0.8,
top_k: int = 1,
repetition_penalty: float = 1.05,
skip_special_tokens: bool = False,
stop: Optional[List[str]] = None,
**kwargs,
):
"""
Initialize the wrapper
Args:
model_path: Path to the model directory
backend: Backend type ("transformers" or "vllm")
system_prompt: System prompt for the model
min_pixels: Minimum pixels for image processing
max_pixels: Maximum pixels for image processing
max_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (0.0 selects greedy decoding)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
repetition_penalty: Penalty applied to repeated tokens (values > 1.0 discourage repetition)
skip_special_tokens: Whether to skip special tokens in the decoded output
stop: Stop sequences for generation (defaults to ["<|im_end|>"])
**kwargs: Additional arguments for model initialization
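        Example (illustrative sketch; the checkpoint path below is a placeholder,
        not a value defined by this module):
            >>> model = RexOmniWrapper(
            ...     model_path="path/to/Rex-Omni",
            ...     backend="transformers",
            ...     max_tokens=2048,
            ... )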
"""
self.model_path = model_path
self.backend = backend.lower()
self.system_prompt = system_prompt
self.min_pixels = min_pixels
self.max_pixels = max_pixels
# Store generation parameters
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.repetition_penalty = repetition_penalty
self.skip_special_tokens = skip_special_tokens
self.stop = stop or ["<|im_end|>"]
# Initialize model and processor
self._initialize_model(**kwargs)
def _initialize_model(self, **kwargs):
"""Initialize model and processor based on backend type"""
print(f"Initializing {self.backend} backend...")
if self.backend == "vllm":
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
# Initialize VLLM model
self.model = LLM(
model=self.model_path,
tokenizer=self.model_path,
tokenizer_mode=kwargs.get("tokenizer_mode", "slow"),
limit_mm_per_prompt=kwargs.get(
"limit_mm_per_prompt", {"image": 10, "video": 10}
),
max_model_len=kwargs.get("max_model_len", 4096),
gpu_memory_utilization=kwargs.get("gpu_memory_utilization", 0.8),
tensor_parallel_size=kwargs.get("tensor_parallel_size", 1),
trust_remote_code=kwargs.get("trust_remote_code", True),
**{
k: v
for k, v in kwargs.items()
if k
not in [
"tokenizer_mode",
"limit_mm_per_prompt",
"max_model_len",
"gpu_memory_utilization",
"tensor_parallel_size",
"trust_remote_code",
]
},
)
# Initialize processor
self.processor = AutoProcessor.from_pretrained(
self.model_path,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
# Set padding side to left for batch inference with Flash Attention
self.processor.tokenizer.padding_side = "left"
# Set up sampling parameters
self.sampling_params = SamplingParams(
max_tokens=self.max_tokens,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
top_k=self.top_k,
temperature=self.temperature,
skip_special_tokens=self.skip_special_tokens,
stop=self.stop,
)
self.model_type = "vllm"
elif self.backend == "transformers":
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
# Initialize transformers model
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_path,
torch_dtype=kwargs.get("torch_dtype", torch.bfloat16),
attn_implementation=kwargs.get(
"attn_implementation", "flash_attention_2"
),
device_map=kwargs.get("device_map", "auto"),
trust_remote_code=kwargs.get("trust_remote_code", True),
**{
k: v
for k, v in kwargs.items()
if k
not in [
"torch_dtype",
"attn_implementation",
"device_map",
"trust_remote_code",
]
},
)
# Initialize processor
self.processor = AutoProcessor.from_pretrained(
self.model_path,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
use_fast=False,
)
# Set padding side to left for batch inference with Flash Attention
self.processor.tokenizer.padding_side = "left"
self.model_type = "transformers"
else:
raise ValueError(
f"Unsupported backend: {self.backend}. Choose 'transformers' or 'vllm'."
)
def inference(
self,
images: Union[Image.Image, List[Image.Image]],
task: Union[str, TaskType, List[Union[str, TaskType]]],
categories: Optional[Union[str, List[str], List[List[str]]]] = None,
keypoint_type: Optional[Union[str, List[str]]] = None,
visual_prompt_boxes: Optional[
Union[List[List[float]], List[List[List[float]]]]
] = None,
**kwargs,
) -> List[Dict[str, Any]]:
"""
Perform batch inference on images for various vision tasks.
Args:
images: Input image(s) in PIL.Image format. Can be single image or list of images.
task: Task type(s). Can be single task or list of tasks for batch processing.
Available options:
- "detection": Object detection with bounding boxes
- "pointing": Point to objects with coordinates
- "visual_prompting": Find similar objects based on reference boxes
- "keypoint": Detect keypoints for persons/hands/animals
- "ocr_box": Detect and recognize text in bounding boxes
- "ocr_polygon": Detect and recognize text in polygons
- "gui_grounding": Detect gui element and return in box format
- "gui_pointing": Point to gui element and return in point format
categories: Object categories to detect/locate. Can be:
- Single string: "person"
- List of strings: ["person", "car"] (applied to all images)
- List of lists: [["person"], ["car", "dog"]] (per-image categories)
keypoint_type: Type of keypoints for keypoint detection task.
Can be single string or list of strings for batch processing.
Options: "person", "hand", "animal"
visual_prompt_boxes: Reference bounding boxes for visual prompting task.
Can be single list or list of lists for batch processing.
Format: [[x0, y0, x1, y1], ...] or [[[x0, y0, x1, y1], ...], ...]
**kwargs: Additional arguments (reserved for future use)
Returns:
List of prediction dictionaries, one for each input image. Each dictionary contains:
- success (bool): Whether inference succeeded
- extracted_predictions (dict): Parsed predictions by category
- raw_output (str): Raw model output text
- inference_time (float): Total wall-clock inference time in seconds (shared by all items in a batch)
- generation_time (float): Time spent in model generation in seconds
- num_output_tokens (int): Number of generated tokens
- num_prompt_tokens (int): Number of input tokens
- tokens_per_second (float): Generation speed
- image_size (tuple): Input image dimensions (width, height)
- resized_size (tuple): Image dimensions after smart_resize (width, height)
- task (str): Task type used
- prompt (str): Generated prompt sent to model
Examples:
# Single image object detection
results = model.inference(
images=image,
task="detection",
categories=["person", "car", "dog"]
)
# Batch processing with same task and categories
results = model.inference(
images=[img1, img2, img3],
task="detection",
categories=["person", "car"]
)
# Batch processing with different tasks per image
results = model.inference(
images=[img1, img2, img3],
task=["detection", "pointing", "keypoint"],
categories=[["person", "car"], ["dog"], ["person"]],
keypoint_type=[None, None, "person"]
)
# Batch keypoint detection with different types
results = model.inference(
images=[img1, img2],
task=["keypoint", "keypoint"],
categories=[["person"], ["hand"]],
keypoint_type=["person", "hand"]
)
# Batch visual prompting
results = model.inference(
images=[img1, img2],
task="visual_prompting",
visual_prompt_boxes=[
[[100, 100, 200, 200]],
[[50, 50, 150, 150], [300, 300, 400, 400]]
]
)
# Mixed batch processing
results = model.inference(
images=[img1, img2, img3],
task=["detection", "ocr_box", "pointing"],
categories=[["person", "car"], ["text"], ["dog"]]
)
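            # Reading the parsed output (illustrative; the exact structure of
            # extracted_predictions depends on parse_prediction and the task)
            for result in results:
                if result["success"]:
                    print(result["task"], result["extracted_predictions"])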
"""
# Convert single image to list
if isinstance(images, Image.Image):
images = [images]
batch_size = len(images)
# Normalize inputs to batch format
tasks, categories_list, keypoint_types, visual_prompt_boxes_list = (
self._normalize_batch_inputs(
task, categories, keypoint_type, visual_prompt_boxes, batch_size
)
)
# Perform batch inference
return self._inference_batch(
images=images,
tasks=tasks,
categories_list=categories_list,
keypoint_types=keypoint_types,
visual_prompt_boxes_list=visual_prompt_boxes_list,
**kwargs,
)
def _normalize_batch_inputs(
self,
task: Union[str, TaskType, List[Union[str, TaskType]]],
categories: Optional[Union[str, List[str], List[List[str]]]],
keypoint_type: Optional[Union[str, List[str]]],
visual_prompt_boxes: Optional[
Union[List[List[float]], List[List[List[float]]]]
],
batch_size: int,
) -> Tuple[
List[TaskType],
List[Optional[List[str]]],
List[Optional[str]],
List[Optional[List[List[float]]]],
]:
"""Normalize all inputs to batch format"""
# Normalize tasks
if isinstance(task, (str, TaskType)):
# Single task for all images
if isinstance(task, str):
task = TaskType(task.lower())
tasks = [task] * batch_size
else:
# List of tasks
tasks = []
for t in task:
if isinstance(t, str):
tasks.append(TaskType(t.lower()))
else:
tasks.append(t)
if len(tasks) != batch_size:
raise ValueError(
f"Number of tasks ({len(tasks)}) must match number of images ({batch_size})"
)
# Normalize categories
if categories is None:
categories_list = [None] * batch_size
elif isinstance(categories, str):
# Single string for all images
categories_list = [[categories]] * batch_size
elif isinstance(categories, list):
if len(categories) == 0:
categories_list = [None] * batch_size
elif isinstance(categories[0], str):
# List of strings for all images
categories_list = [categories] * batch_size
else:
# List of lists (per-image categories)
categories_list = categories
if len(categories_list) != batch_size:
raise ValueError(
f"Number of category lists ({len(categories_list)}) must match number of images ({batch_size})"
)
else:
categories_list = [None] * batch_size
# Normalize keypoint_type
if keypoint_type is None:
keypoint_types = [None] * batch_size
elif isinstance(keypoint_type, str):
# Single keypoint type for all images
keypoint_types = [keypoint_type] * batch_size
else:
# List of keypoint types
keypoint_types = keypoint_type
if len(keypoint_types) != batch_size:
raise ValueError(
f"Number of keypoint types ({len(keypoint_types)}) must match number of images ({batch_size})"
)
# Normalize visual_prompt_boxes
if visual_prompt_boxes is None:
visual_prompt_boxes_list = [None] * batch_size
elif isinstance(visual_prompt_boxes, list):
if len(visual_prompt_boxes) == 0:
visual_prompt_boxes_list = [None] * batch_size
elif isinstance(visual_prompt_boxes[0], (int, float)):
# Single box for all images: [x0, y0, x1, y1]
visual_prompt_boxes_list = [[visual_prompt_boxes]] * batch_size
elif isinstance(visual_prompt_boxes[0], list):
if len(visual_prompt_boxes[0]) == 4 and isinstance(
visual_prompt_boxes[0][0], (int, float)
):
# List of boxes for all images: [[x0, y0, x1, y1], ...]
visual_prompt_boxes_list = [visual_prompt_boxes] * batch_size
else:
# List of lists of boxes (per-image boxes): [[[x0, y0, x1, y1], ...], ...]
visual_prompt_boxes_list = visual_prompt_boxes
if len(visual_prompt_boxes_list) != batch_size:
raise ValueError(
f"Number of visual prompt box lists ({len(visual_prompt_boxes_list)}) must match number of images ({batch_size})"
)
else:
visual_prompt_boxes_list = [None] * batch_size
else:
visual_prompt_boxes_list = [None] * batch_size
return tasks, categories_list, keypoint_types, visual_prompt_boxes_list
def _inference_batch(
self,
images: List[Image.Image],
tasks: List[TaskType],
categories_list: List[Optional[List[str]]],
keypoint_types: List[Optional[str]],
visual_prompt_boxes_list: List[Optional[List[List[float]]]],
**kwargs,
) -> List[Dict[str, Any]]:
"""Perform true batch inference"""
start_time = time.time()
batch_size = len(images)
# Prepare batch data
batch_messages = []
batch_prompts = []
batch_image_sizes = []
for i in range(batch_size):
image = images[i]
task = tasks[i]
categories = categories_list[i]
keypoint_type = keypoint_types[i]
visual_prompt_boxes = visual_prompt_boxes_list[i]
# Get image dimensions
w, h = image.size
batch_image_sizes.append((w, h))
# Generate prompt
prompt = self._generate_prompt(
task=task,
categories=categories,
keypoint_type=keypoint_type,
visual_prompt_boxes=visual_prompt_boxes,
image_width=w,
image_height=h,
)
batch_prompts.append(prompt)
# Calculate resized dimensions (smart_resize rounds to multiples of 28, Qwen2.5-VL's patch size x merge factor)
resized_height, resized_width = smart_resize(
h,
w,
28,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
# Prepare messages
if self.model_type == "transformers":
messages = [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
"resized_height": resized_height,
"resized_width": resized_width,
},
{"type": "text", "text": prompt},
],
},
]
else:
messages = [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
"min_pixels": self.min_pixels,
"max_pixels": self.max_pixels,
},
{"type": "text", "text": prompt},
],
},
]
batch_messages.append(messages)
# Perform batch generation
if self.model_type == "vllm":
batch_outputs, batch_generation_info = self._generate_vllm_batch(
batch_messages
)
else:
batch_outputs, batch_generation_info = self._generate_transformers_batch(
batch_messages, images
)
# Parse results
results = []
total_time = time.time() - start_time
for i in range(batch_size):
raw_output = batch_outputs[i]
generation_info = batch_generation_info[i]
w, h = batch_image_sizes[i]
task = tasks[i]
prompt = batch_prompts[i]
# Parse predictions
extracted_predictions = parse_prediction(
text=raw_output,
w=w,
h=h,
task_type=task.value,
)
# Calculate resized dimensions for result
resized_height, resized_width = smart_resize(
h,
w,
28,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
result = {
"success": True,
"image_size": (w, h),
"resized_size": (resized_width, resized_height),
"task": task.value,
"prompt": prompt,
"raw_output": raw_output,
"extracted_predictions": extracted_predictions,
"inference_time": total_time, # Total batch time
**generation_info,
}
results.append(result)
return results
def _inference_single(
self,
image: Image.Image,
task: TaskType,
categories: Optional[Union[str, List[str]]] = None,
keypoint_type: Optional[str] = None,
visual_prompt_boxes: Optional[List[List[float]]] = None,
**kwargs,
) -> Dict[str, Any]:
"""Perform inference on a single image"""
start_time = time.time()
# Get image dimensions
w, h = image.size
# Generate prompt based on task
final_prompt = self._generate_prompt(
task=task,
categories=categories,
keypoint_type=keypoint_type,
visual_prompt_boxes=visual_prompt_boxes,
image_width=w,
image_height=h,
)
# Calculate resized dimensions using smart_resize
resized_height, resized_width = smart_resize(
h,
w,
28,
min_pixels=self.min_pixels,
max_pixels=self.max_pixels,
)
# Prepare messages
if self.model_type == "transformers":
# For transformers, use resized_height and resized_width
messages = [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
"resized_height": resized_height,
"resized_width": resized_width,
},
{"type": "text", "text": final_prompt},
],
},
]
else:
# For VLLM, use min_pixels and max_pixels
messages = [
{"role": "system", "content": self.system_prompt},
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
"min_pixels": self.min_pixels,
"max_pixels": self.max_pixels,
},
{"type": "text", "text": final_prompt},
],
},
]
# Generate response
if self.model_type == "vllm":
raw_output, generation_info = self._generate_vllm(messages)
else:
raw_output, generation_info = self._generate_transformers(messages)
# Parse predictions
extracted_predictions = parse_prediction(
text=raw_output,
w=w,
h=h,
task_type=task.value,
)
# Calculate timing
total_time = time.time() - start_time
return {
"success": True,
"image_size": (w, h),
"resized_size": (resized_width, resized_height),
"task": task.value,
"prompt": final_prompt,
"raw_output": raw_output,
"extracted_predictions": extracted_predictions,
"inference_time": total_time,
**generation_info,
}
def _generate_prompt(
self,
task: TaskType,
categories: Optional[Union[str, List[str]]] = None,
keypoint_type: Optional[str] = None,
visual_prompt_boxes: Optional[List[List[float]]] = None,
image_width: Optional[int] = None,
image_height: Optional[int] = None,
) -> str:
"""Generate prompt based on task configuration"""
task_config = get_task_config(task)
if task == TaskType.VISUAL_PROMPTING:
if visual_prompt_boxes is None:
raise ValueError(
"Visual prompt boxes are required for visual prompting task"
)
# Convert boxes to normalized bins format
word_mapped_boxes = convert_boxes_to_normalized_bins(
visual_prompt_boxes, image_width, image_height
)
visual_prompt_dict = {"object_1": word_mapped_boxes}
visual_prompt_json = json.dumps(visual_prompt_dict)
return task_config.prompt_template.format(visual_prompt=visual_prompt_json)
elif task == TaskType.KEYPOINT:
if categories is None:
raise ValueError("Categories are required for keypoint task")
if keypoint_type is None:
raise ValueError("Keypoint type is required for keypoint task")
keypoints_list = get_keypoint_config(keypoint_type)
if keypoints_list is None:
raise ValueError(f"Unknown keypoint type: {keypoint_type}")
keypoints_str = ", ".join(keypoints_list)
categories_str = (
", ".join(categories) if isinstance(categories, list) else categories
)
return task_config.prompt_template.format(
categories=categories_str, keypoints=keypoints_str
)
else:
# Standard tasks (detection, pointing, OCR, etc.)
if task_config.requires_categories and categories is None:
raise ValueError(f"Categories are required for {task.value} task")
if categories is not None:
categories_str = (
", ".join(categories)
if isinstance(categories, list)
else categories
)
return task_config.prompt_template.format(categories=categories_str)
else:
return task_config.prompt_template.format(categories="objects")
def _generate_vllm(self, messages: List[Dict]) -> Tuple[str, Dict]:
"""Generate using VLLM model"""
# Process vision info
image_inputs, video_inputs = process_vision_info(messages)
mm_data = {"image": image_inputs}
prompt = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
llm_inputs = {
"prompt": prompt,
"multi_modal_data": mm_data,
}
# Generate
generation_start = time.time()
outputs = self.model.generate(
[llm_inputs], sampling_params=self.sampling_params
)
generation_time = time.time() - generation_start
generated_text = outputs[0].outputs[0].text
# Extract token information
output_tokens = outputs[0].outputs[0].token_ids
num_output_tokens = len(output_tokens) if output_tokens else 0
prompt_token_ids = outputs[0].prompt_token_ids
num_prompt_tokens = len(prompt_token_ids) if prompt_token_ids else 0
tokens_per_second = (
num_output_tokens / generation_time if generation_time > 0 else 0
)
return generated_text, {
"num_output_tokens": num_output_tokens,
"num_prompt_tokens": num_prompt_tokens,
"generation_time": generation_time,
"tokens_per_second": tokens_per_second,
}
def _generate_vllm_batch(
self, batch_messages: List[List[Dict]]
) -> Tuple[List[str], List[Dict]]:
"""Generate using VLLM model for batch processing"""
# Process all messages
batch_inputs = []
for messages in batch_messages:
# Process vision info
image_inputs, video_inputs = process_vision_info(messages)
mm_data = {"image": image_inputs}
prompt = self.processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
llm_inputs = {
"prompt": prompt,
"multi_modal_data": mm_data,
}
batch_inputs.append(llm_inputs)
# Generate for entire batch
generation_start = time.time()
outputs = self.model.generate(
batch_inputs, sampling_params=self.sampling_params
)
generation_time = time.time() - generation_start
# Extract results
batch_outputs = []
batch_generation_info = []
for output in outputs:
generated_text = output.outputs[0].text
batch_outputs.append(generated_text)
# Extract token information
output_tokens = output.outputs[0].token_ids
num_output_tokens = len(output_tokens) if output_tokens else 0
prompt_token_ids = output.prompt_token_ids
num_prompt_tokens = len(prompt_token_ids) if prompt_token_ids else 0
tokens_per_second = (
num_output_tokens / generation_time if generation_time > 0 else 0
)
generation_info = {
"num_output_tokens": num_output_tokens,
"num_prompt_tokens": num_prompt_tokens,
"generation_time": generation_time,
"tokens_per_second": tokens_per_second,
}
batch_generation_info.append(generation_info)
return batch_outputs, batch_generation_info
def _generate_transformers(self, messages: List[Dict]) -> Tuple[str, Dict]:
"""Generate using Transformers model"""
# Apply chat template
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Process inputs
generation_start = time.time()
inputs = self.processor(
text=[text],
images=[messages[1]["content"][0]["image"]],
padding=True,
return_tensors="pt",
).to(self.model.device)
# Prepare generation kwargs
generation_kwargs = {
"max_new_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"repetition_penalty": self.repetition_penalty,
"do_sample": self.temperature > 0, # Enable sampling if temperature > 0
"pad_token_id": self.processor.tokenizer.eos_token_id,
}
# Generate
with torch.no_grad():
generated_ids = self.model.generate(**inputs, **generation_kwargs)
generation_time = time.time() - generation_start
# Trim the prompt tokens and decode only the newly generated tokens
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=self.skip_special_tokens,
clean_up_tokenization_spaces=False,
)[0]
num_output_tokens = len(generated_ids_trimmed[0])
num_prompt_tokens = len(inputs.input_ids[0])
tokens_per_second = (
num_output_tokens / generation_time if generation_time > 0 else 0
)
return output_text, {
"num_output_tokens": num_output_tokens,
"num_prompt_tokens": num_prompt_tokens,
"generation_time": generation_time,
"tokens_per_second": tokens_per_second,
}
def _generate_transformers_batch(
self, batch_messages: List[List[Dict]], batch_images: List[Image.Image]
) -> Tuple[List[str], List[Dict]]:
"""Generate using Transformers model for batch processing"""
# Prepare batch inputs
batch_texts = []
for messages in batch_messages:
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
batch_texts.append(text)
# Process inputs for batch
generation_start = time.time()
inputs = self.processor(
text=batch_texts,
images=batch_images,
padding=True,
return_tensors="pt",
).to(self.model.device)
# Prepare generation kwargs
generation_kwargs = {
"max_new_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"repetition_penalty": self.repetition_penalty,
"do_sample": self.temperature > 0,
"pad_token_id": self.processor.tokenizer.eos_token_id,
}
# Generate for entire batch
with torch.no_grad():
generated_ids = self.model.generate(**inputs, **generation_kwargs)
generation_time = time.time() - generation_start
# Decode batch results
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
batch_outputs = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=self.skip_special_tokens,
clean_up_tokenization_spaces=False,
)
# Prepare generation info for each item
batch_generation_info = []
for i, output_ids in enumerate(generated_ids_trimmed):
num_output_tokens = len(output_ids)
num_prompt_tokens = len(inputs.input_ids[i])
tokens_per_second = (
num_output_tokens / generation_time if generation_time > 0 else 0
)
generation_info = {
"num_output_tokens": num_output_tokens,
"num_prompt_tokens": num_prompt_tokens,
"generation_time": generation_time,
"tokens_per_second": tokens_per_second,
}
batch_generation_info.append(generation_info)
return batch_outputs, batch_generation_info
def get_supported_tasks(self) -> List[str]:
"""Get list of supported tasks"""
return [task.value for task in TaskType]
def get_task_info(self, task: Union[str, TaskType]) -> Dict[str, Any]:
"""Get information about a specific task"""
if isinstance(task, str):
task = TaskType(task.lower())
config = get_task_config(task)
return {
"name": config.name,
"description": config.description,
"output_format": config.output_format,
"requires_categories": config.requires_categories,
"requires_visual_prompt": config.requires_visual_prompt,
"requires_keypoint_type": config.requires_keypoint_type,
"prompt_template": config.prompt_template,
}
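

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): demonstrates single-image object
# detection with the wrapper above. The checkpoint directory and image path
# are placeholders and are not shipped with this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = RexOmniWrapper(
        model_path="path/to/Rex-Omni",  # placeholder checkpoint directory
        backend="transformers",
    )
    image = Image.open("path/to/example.jpg").convert("RGB")  # placeholder image
    results = model.inference(
        images=image,
        task="detection",
        categories=["person", "car"],
    )
    for result in results:
        print(f"task={result['task']}  time={result['inference_time']:.2f}s")
        print(result["extracted_predictions"])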