import copy
import os
import re
import uuid
from typing import Dict, List, Optional, Union

import numpy as np
import PIL
from PIL import Image
import torch
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import (
    AllKwargsForChatTemplate,
    ChatTemplateLoadKwargs,
    ProcessingKwargs,
    ProcessorMixin,
    TextInput,
    Unpack,
    VideoInput,
)
from transformers.tokenization_utils_base import AudioInput
from transformers.utils import (
    is_torch_device,
    is_torch_dtype,
    logging,
    requires_backends,
)
from transformers.video_utils import VideoInput, VideoMetadata, load_video


logger = logging.get_logger(__name__)


class HCXBatchFeature(BatchFeature):
    def to(self, *args, **kwargs) -> "BatchFeature":
        """
        Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
        different `dtypes` and sending the `BatchFeature` to a different `device`.

        Args:
            args (`Tuple`):
                Will be passed to the `to(...)` function of the tensors.
            kwargs (`Dict`, *optional*):
                Will be passed to the `to(...)` function of the tensors.
                To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).

        Returns:
            [`BatchFeature`]: The same instance after modification.
        """
        requires_backends(self, ["torch"])
        import torch

        new_data = {}
        device = kwargs.get("device")
        non_blocking = kwargs.get("non_blocking", False)

        # check whether the first positional argument is a dtype or a device
        if device is None and len(args) > 0:
            arg = args[0]
            if is_torch_dtype(arg):
                # the first argument is a dtype; casting is handled by `v.to(*args, **kwargs)` below
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                # anything else is not a valid cast target
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        for k, v in self.items():
            if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
                # cast floating-point tensors and send them to the device
                new_data[k] = v.to(*args, **kwargs)
            elif isinstance(v, torch.Tensor) and device is not None:
                new_data[k] = v.to(device=device, non_blocking=non_blocking)
            elif "pixel_values" in k:
                # pixel values are nested lists of tensors, so move them entry by entry
                new_pixel_values_batch = []
                for _v in v:
                    pixel_values = [pixel_value.to(device=device, non_blocking=non_blocking) for pixel_value in _v]
                    new_pixel_values_batch.append(pixel_values)
                new_data[k] = new_pixel_values_batch
            else:
                new_data[k] = v
        self.data = new_data
        return self
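

# A minimal usage sketch (hypothetical tensors, not part of this module): `HCXBatchFeature.to`
# moves flat tensors directly and walks the nested `pixel_values_*` lists entry by entry, so a
# batch produced by `HCXProcessor.__call__` can be sent to an accelerator in one call, e.g.
#
#     batch = HCXBatchFeature(data={
#         "input_ids": torch.ones(1, 8, dtype=torch.long),
#         "pixel_values_images": [[torch.randn(3, 378, 378)]],
#     })
#     batch = batch.to("cuda", non_blocking=True)  # assumes a CUDA device is available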


class HCXProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "return_tensors": "pt",
            "calc_non_vision_query_lengths": False,
        },
        "images_kwargs": {},
        "audio_kwargs": {},
        "videos_kwargs": {
            "max_image_cnt": 12,
            "max_num_grids": 9,
        },
    }
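

# A minimal sketch of how these defaults behave (hypothetical call, not executed here): per-call
# keyword arguments are merged on top of `_defaults` by `ProcessorMixin._merge_kwargs`, so an
# override like
#
#     processor(text=prompt, videos=videos, videos_kwargs={"max_num_grids": 4})
#
# would keep "max_image_cnt" at its default of 12 while capping the grid count at 4.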


class HCXProcessor(ProcessorMixin):
    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]

    image_processor_class = "AutoImageProcessor"
    tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
        self.image_token = "<|dummy3|>"
        self.video_token = "<|_unuse_missing_100270|>"
        self.image_token_pattern = re.compile(r"<\|dummy3\|>")
        self.video_token_pattern = re.compile(r"<\|_unuse_missing_100270\|>")
        self.image_video_token_pattern = re.compile(r"<\|dummy3\|>|<\|_unuse_missing_100270\|>")
        self.image_token_id = (
            tokenizer.image_token_id
            if getattr(tokenizer, "image_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.image_token)
        )
        self.video_token_id = (
            tokenizer.video_token_id
            if getattr(tokenizer, "video_token_id", None)
            else tokenizer.convert_tokens_to_ids(self.video_token)
        )
        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def apply_chat_template(
        self,
        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
        chat_template: Optional[str] = None,
        **kwargs: Unpack[AllKwargsForChatTemplate],
    ) -> str:
        model_inputs = super().apply_chat_template(conversation, chat_template, **kwargs)

        # drop the intermediate vision query lengths from the templated model inputs
        del model_inputs["vision_query_lengths_images"]
        del model_inputs["vision_query_lengths_videos"]

        return model_inputs

    def repeat_dummy_tokens(self, input_ids, target_token_id, vision_query_lengths):
        """
        Repeat each `target_token_id` placeholder in `input_ids` according to the matching entry
        in `vision_query_lengths`.
        """
        input_ids = input_ids.clone().detach()
        batch_indices, target_indices = torch.where(input_ids == target_token_id)
        batch_size = input_ids.shape[0]

        new_input_ids = [[] for _ in range(batch_size)]
        start_indices = [0 for _ in range(batch_size)]
        counter = [0 for _ in range(batch_size)]
        for batch_idx, target_idx in zip(batch_indices, target_indices):
            start_idx = start_indices[batch_idx]
            new_input_ids[batch_idx].append(input_ids[batch_idx][start_idx:target_idx])
            query_length = vision_query_lengths[batch_idx][counter[batch_idx]]
            new_input_ids[batch_idx].append(input_ids[batch_idx][target_idx].repeat(query_length))
            start_indices[batch_idx] = target_idx + 1
            counter[batch_idx] += 1

        for batch_idx in range(batch_size):
            start_idx = start_indices[batch_idx]
            new_input_ids[batch_idx].append(input_ids[batch_idx][start_idx:])
            new_input_ids[batch_idx] = torch.cat(new_input_ids[batch_idx], dim=0)

        new_input_ids = torch.stack(new_input_ids)
        return new_input_ids
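
    # A worked sketch of `repeat_dummy_tokens` (hypothetical values): with
    #     input_ids            = [[ 1, IMG, 2, IMG, 3 ]]      (IMG = self.image_token_id)
    #     vision_query_lengths = [[ 4, 2 ]]
    # each IMG placeholder is repeated to its query length, giving
    #     [[ 1, IMG, IMG, IMG, IMG, 2, IMG, IMG, 3 ]]
    # so there is one placeholder position per vision feature to scatter into downstream.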

    def _load_video_for_model(
        self,
        video: str,
        num_frames: Optional[int] = None,
        fps: Optional[int] = None,
        backend: str = "opencv",
        **kwargs: Unpack[HCXProcessorKwargs],
    ) -> List[ImageInput]:
        """
        Overridden function.

        Loads `video` into a list of sampled frames (llava style).

        Args:
            video (`str`):
                The video to convert to the numpy array format. Can be a link to video or local path.
            num_frames (`int`, *optional*):
                Number of frames to sample uniformly. If not passed, the whole video is loaded.
            fps (`int`, *optional*):
                Number of frames to sample per second. Should be passed only when `num_frames=None`.
                If not specified and `num_frames==None`, all frames are sampled.
            backend (`str`, *optional*, defaults to `"opencv"`):
                The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"].

        Returns:
            Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
                - the sampled frames in RGB.
                - metadata of the loaded video.
        """
        output_kwargs = self._merge_kwargs(
            HCXProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        logger.warning_once(f"num_frames control via argument is not supported yet. Ignored num_frames: {num_frames}.")
        logger.warning_once(f"fps control via argument is not supported yet. Ignored fps: {fps}.")
        logger.warning_once(f"backend control via argument is not supported yet. Ignored backend: {backend}.")

        def _hcx_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
            # keep at most `max_num_grids * max_image_cnt` frames, sampled by `extract_frame_indices`
            max_num_grids = output_kwargs["videos_kwargs"]["max_num_grids"]
            max_image_cnt = output_kwargs["videos_kwargs"]["max_image_cnt"]
            frame_indices, time_interval = extract_frame_indices(
                metadata.duration,
                metadata.total_num_frames,
                metadata.fps,
                max_num_grids,
                max_image_cnt,
                default_interval=0.4,
            )
            metadata.time_interval = time_interval
            return np.array(frame_indices)

        video_loaded, video_metadata = None, None
        for backend in ["decord", "pyav", "opencv", "torchvision"]:
            try:
                video_loaded, video_metadata = load_video(
                    video, sample_indices_fn=_hcx_sample_indices_fn, backend=backend
                )
                break
            except Exception as e:
                logger.error(f"Error loading video with {backend} backend: {e}")
                continue

        assert video_loaded is not None, "Failed to load video with any backend"

        return video_loaded, video_metadata

    def _process_messages_for_chat_template(
        self,
        conversation: List[List[Dict[str, str]]],
        batch_images: List[List[ImageInput]],
        batch_videos: List[List[VideoInput]],
        batch_video_metadata: List[List[Dict[str, any]]],
        **mm_load_kwargs: Unpack[ChatTemplateLoadKwargs],
    ):
        """
        Overridden function.
        Used within `apply_chat_template` when a model has a special way to process conversation history. For example,
        video models might want to specify in the prompt the duration of video or which frame indices at which
        timestamps were sampled. This information cannot be accessed before the video is loaded.

        For most models it is a no-op, and must be overridden by model processors which require special processing.

        Args:
            conversation (`List[List[Dict[str, str]]]`):
                The conversation to process. Always comes in batched format.
            batch_images (`List[List[ImageInput]]`):
                Batch of images that were loaded from url/path defined in the conversation. The images
                are ordered in the same way as in the conversation. Comes in nested list format, one list of `PIL`
                images per batch.
            batch_videos (`List[List[VideoInput]]`):
                Batch of videos that were loaded from url/path defined in the conversation. The videos
                are ordered in the same way as in the conversation. Comes in nested list format, one list of video
                frames per batch.
            batch_video_metadata (`List[List[Dict[str, any]]]`):
                Batch of metadata returned from loading videos. That includes video fps, duration and total number of
                frames in the original video. Metadata are ordered in the same way as `batch_videos`. Comes in nested
                list format, one list of `Dict` per batch.
        """
        is_video_in_conversation = False
        for batch_idx, messages in enumerate(conversation):
            is_video_in_messages = False
            is_image_in_messages = False
            for message in messages:
                for content in message["content"]:
                    if content["type"] == "video":
                        is_video_in_messages = True
                    elif content["type"] == "image":
                        is_image_in_messages = True
            if not is_video_in_messages:
                batch_videos.insert(batch_idx, [])
                batch_video_metadata.insert(batch_idx, [])
            if not is_image_in_messages:
                batch_images.insert(batch_idx, [])

            is_video_in_conversation = is_video_in_conversation or is_video_in_messages

        if not is_video_in_conversation:
            return conversation

        # replace each video content entry with one entry per combined frame grid, carrying the
        # time span covered by that grid
        new_conversation = []
        for batch_idx, messages in enumerate(conversation):
            video_counter = 0
            new_messages = []

            for message in messages:
                new_message = {
                    "role": message["role"],
                    "content": [],
                }
                for content in message["content"]:
                    if content["type"] == "video":
                        video = batch_videos[batch_idx][video_counter]
                        video_meta = batch_video_metadata[batch_idx][video_counter]

                        time_stamps = calc_timestamp_video_grids(video, video_meta.time_interval, max_grid_shape=(3, 3))
                        video_counter += 1

                        if "filename" in content:
                            filename = content["filename"]
                        else:
                            filename = content["video"].split("/")[-1]
                            if len(filename) > 50:
                                filename = f"{uuid.uuid4().hex}.mp4"
                        basename, ext = os.path.splitext(filename)
                        if ext == "":
                            ext = ".mp4"

                        for frame_idx, time_stamp in enumerate(time_stamps):
                            if frame_idx == len(time_stamps) - 1:
                                # the last grid also carries the video-level metadata and closes the video block
                                new_content = {
                                    "filename": f"{basename}-{frame_idx}{ext}",
                                    "video": content["video"],
                                    "type": "video",
                                    "video_time_stamp": time_stamp,
                                    "lens_keywords": content["lens_keywords"],
                                    "lens_local_keywords": content["lens_local_keywords"],
                                    "speech_to_text": content["speech_to_text"],
                                    "is_final_grid": True,
                                }
                                new_message["content"].append(new_content)
                            else:
                                new_content = {
                                    "filename": f"{basename}-{frame_idx}{ext}",
                                    "video": content["video"],
                                    "type": "video",
                                    "video_time_stamp": time_stamp,
                                }
                                new_message["content"].append(new_content)
                    else:
                        new_message["content"].append(copy.deepcopy(content))
                new_messages.append(new_message)
            new_conversation.append(new_messages)

        return new_conversation

    def __call__(
        self,
        text: TextInput = None,
        images: List[List[ImageInput]] = None,
        videos: List[List[VideoInput]] = None,
        audio: AudioInput = None,
        **kwargs: Unpack[HCXProcessorKwargs],
    ):
        output_kwargs = self._merge_kwargs(
            HCXProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        # multimodal inputs, one nested list per conversation in the batch
        mm_inputs = {
            "pixel_values_images": [],
            "image_sizes_images": [],
            "vision_query_lengths_images": [],
            "pixel_values_videos": [],
            "vision_query_lengths_videos": [],
        }
        calc_non_vision_query_lengths = output_kwargs["text_kwargs"].pop("calc_non_vision_query_lengths")
        if calc_non_vision_query_lengths:
            mm_inputs["non_vision_query_lengths"] = []

        # videos: combine sampled frames into grid images, then run the image processor
        if videos is not None:
            vit_input_size = self.image_processor.crop_size["width"]

            video_kwargs = copy.deepcopy(output_kwargs["videos_kwargs"])

            for videos_in_single_conversation in videos:
                pixel_values_videos = []
                vision_query_lengths_videos = []

                for video_frames in videos_in_single_conversation:
                    if len(video_frames) == 0:
                        mm_inputs["pixel_values_videos"].append([])
                        mm_inputs["vision_query_lengths_videos"].append([])
                        continue
                    video_frames_combined = combine_frames_into_images(
                        video_frames, max_grid_shape=(3, 3), vit_input_size=vit_input_size
                    )
                    video_kwargs["is_video"] = True
                    video_kwargs["return_tensors"] = None

                    frames_processed = self.image_processor(images=video_frames_combined, **video_kwargs)
                    sizes = [(size["width"], size["height"]) for size in frames_processed["image_sizes"]]

                    pixel_values_videos.extend(frames_processed["pixel_values"])
                    vision_query_lengths_videos.extend(frames_processed["vision_query_lengths"])

                mm_inputs["pixel_values_videos"].append(pixel_values_videos)
                mm_inputs["vision_query_lengths_videos"].append(vision_query_lengths_videos)

        # images: run the image processor per conversation
        if images is not None:
            image_kwargs = copy.deepcopy(output_kwargs["images_kwargs"])
            image_kwargs["is_video"] = False
            image_kwargs["return_tensors"] = None

            for images_in_single_conversation in images:
                if isinstance(images_in_single_conversation, PIL.Image.Image):
                    images_in_single_conversation = [images_in_single_conversation]
                if len(images_in_single_conversation) == 0:
                    mm_inputs["pixel_values_images"].append([])
                    mm_inputs["image_sizes_images"].append([])
                    mm_inputs["vision_query_lengths_images"].append([])
                    continue
                images_processed = self.image_processor(images=images_in_single_conversation, **image_kwargs)
                sizes = [(size["width"], size["height"]) for size in images_processed["image_sizes"]]

                mm_inputs["pixel_values_images"].append(images_processed["pixel_values"])
                mm_inputs["image_sizes_images"].append(sizes)
                mm_inputs["vision_query_lengths_images"].append(images_processed["vision_query_lengths"])

        # text: expand each placeholder token to its vision query length, then tokenize
        def _create_replacer(_target_token, _replacements):
            _iterator = iter(_replacements)

            def _replacer(match_obj):
                # replace the matched placeholder with `num_query_tokens` copies of itself
                num_query_tokens = next(_iterator)
                return "".join([_target_token for _ in range(num_query_tokens)])

            return _replacer

        text_inputs = {}
        if text is not None:
            if not isinstance(text, list):
                text = [text]

            if images is not None:
                new_texts = []
                for batch_idx, text_in_single_conversation in enumerate(text):
                    new_text = self.image_token_pattern.sub(
                        _create_replacer(self.image_token, mm_inputs["vision_query_lengths_images"][batch_idx]),
                        text_in_single_conversation,
                    )
                    new_texts.append(new_text)
                text = new_texts

            if videos is not None:
                new_texts = []
                for batch_idx, text_in_single_conversation in enumerate(text):
                    new_text = self.video_token_pattern.sub(
                        _create_replacer(self.video_token, mm_inputs["vision_query_lengths_videos"][batch_idx]),
                        text_in_single_conversation,
                    )
                    new_texts.append(new_text)
                text = new_texts

            text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        # audio is not supported by this processor
        if audio is not None:
            raise NotImplementedError("Audio processing is not supported yet.")

        return HCXBatchFeature(data={**text_inputs, **mm_inputs})
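
    # A minimal usage sketch (hypothetical checkpoint path and image file, not executed here):
    # the text is expected to already contain one "<|dummy3|>" placeholder per image, and media
    # come as nested lists, one inner list per conversation, e.g.
    #
    #     processor = HCXProcessor.from_pretrained("path/to/checkpoint")   # hypothetical path
    #     batch = processor(
    #         text=["Describe this image. <|dummy3|>"],
    #         images=[[PIL.Image.open("example.jpg")]],                    # hypothetical file
    #     )
    #     # batch holds input_ids with the placeholder repeated to the vision query length,
    #     # plus pixel_values_images / image_sizes_images for the vision tower.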

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the underlying tokenizer's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def post_process_image_text_to_text(
        self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
    ):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.
            skip_special_tokens (`bool`, *optional*, defaults to `True`):
                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's
                `batch_decode` method.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's
                `batch_decode` method.
            **kwargs:
                Additional arguments to be passed to the tokenizer's `batch_decode` method.

        Returns:
            `List[str]`: The decoded text.
        """
        return self.tokenizer.batch_decode(
            generated_outputs,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
        return names_from_processor


def extract_frame_indices(play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=0.4):
    """
    Extracts specific frame indices from a video based on duration, frame count, and sampling strategy.

    The function determines which frames to extract given the video duration (`play_time`),
    total frame count, and frame rate. It samples frames at regular intervals (default: 0.4s),
    but if the number of frames exceeds the limit defined by `max_num_grids * max_image_cnt`,
    it performs uniform sampling to stay within that limit.

    Args:
        play_time (float): Total play time of the video in seconds.
        total_frames (int): Total number of frames in the video.
        fps (float): Frames per second of the video.
        max_num_grids (int): Maximum number of grids to display.
        max_image_cnt (int): Maximum number of images per grid.
        default_interval (float, optional): Interval in seconds between frame samples. Defaults to 0.4.

    Returns:
        Tuple:
            frame_indices (List[int]): A list of selected frame indices.
            time_interval (float): Time interval between selected frames (in seconds).
    """
    # number of frames that would be taken at the default sampling interval
    default_frame_count = int(play_time / default_interval)

    # hard cap imposed by the grid layout
    max_frames_allowed = max_num_grids * max_image_cnt

    if default_frame_count <= max_frames_allowed:
        # sample at (roughly) the default interval
        frame_interval = int(total_frames / default_frame_count)
    else:
        # sample uniformly so that no more than `max_frames_allowed` frames are kept
        frame_interval = int(total_frames / max_frames_allowed)

    selected_indices = list(range(0, total_frames, frame_interval))

    time_interval = frame_interval / fps

    return selected_indices[:max_frames_allowed], time_interval
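

# A worked sketch (hypothetical numbers): for a 60 s clip at 30 fps (1800 frames) with
# max_num_grids=9 and max_image_cnt=12, the default 0.4 s interval would yield 150 frames,
# which exceeds the cap of 9 * 12 = 108, so frames are taken every 1800 // 108 = 16 frames:
#
#     indices, interval = extract_frame_indices(60.0, 1800, 30.0, 9, 12)
#     # len(indices) == 108 (capped), interval == 16 / 30 ≈ 0.53 s between sampled frames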


def calc_timestamp_video_grids(frames, time_interval, max_grid_shape=(3, 3)):
    """
    Calculates the time range labels for each grid in a video.

    Args:
        frames (List[PIL.Image.Image]): A list of frames extracted from a video.
        time_interval (float): Time interval (in seconds) between consecutive frames.
        max_grid_shape (Tuple[int, int], optional): The maximum grid shape as (rows, cols). Defaults to (3, 3).

    Returns:
        image_time_stamps (List[str]): A list of time span labels for each combined image,
            e.g., ["0.00s~1.50s", "1.50s~3.00s", ...].
    """
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]

    # how many full grids the frames fill, plus any leftover frames for a final partial grid
    num_frames = len(frames)
    num_canvases = num_frames // max_num_grids
    leftover_frames = num_frames % max_num_grids

    time_stamp = 0
    image_time_stamps = []

    for canvas_idx in range(num_canvases):
        # label the time span covered by this full grid
        start_idx = canvas_idx * max_num_grids
        end_idx = min(start_idx + max_num_grids, num_frames)

        frame_cnt = end_idx - start_idx
        image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
        time_stamp += frame_cnt * time_interval

    if leftover_frames > 0:
        # label the final, partially filled grid
        frame_cnt = leftover_frames
        image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
        time_stamp += frame_cnt * time_interval

    return image_time_stamps
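

# A worked sketch (hypothetical numbers): 12 frames sampled 0.5 s apart with a 3x3 grid give one
# full grid of 9 frames plus a partial grid of 3 frames, so
#
#     calc_timestamp_video_grids(frames, time_interval=0.5, max_grid_shape=(3, 3))
#     # -> ["0.00s~4.50s", "4.50s~6.00s"]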


def combine_frames_into_images(frames, max_grid_shape=(3, 3), vit_input_size=378):
    """
    Combines a sequence of video frames into grid-based images.

    Frames are grouped and arranged into a grid (e.g., 3x3) such that each combined image contains up to
    `max_grid_shape[0] * max_grid_shape[1]` frames. Each frame is resized to the given ViT input size before being
    pasted into its grid cell.

    Args:
        frames (NDArray): Array of shape (num_frames, H, W, C) holding the frames extracted from a video.
        max_grid_shape (Tuple[int, int], optional): The maximum grid shape as (rows, cols). Defaults to (3, 3).
        vit_input_size (int, optional): The target size (height and width) for the Vision Transformer input. Defaults to 378.

    Returns:
        image_list (List[PIL.Image.Image]): A list of grid-combined images.
    """
    max_num_grids = max_grid_shape[0] * max_grid_shape[1]

    image_list = []

    # how many full grids the frames fill, plus any leftover frames for a final partial grid
    num_frames = len(frames)
    num_canvases = num_frames // max_num_grids
    leftover_frames = num_frames % max_num_grids

    # convert the raw frame arrays to PIL images
    frames = [Image.fromarray(frame) for frame in frames]

    for canvas_idx in range(num_canvases):
        # black canvas holding one full grid of frames
        combined_image = Image.new(
            "RGB", (vit_input_size * max_grid_shape[0], vit_input_size * max_grid_shape[1]), color=(0, 0, 0)
        )

        start_idx = canvas_idx * max_num_grids
        end_idx = min(start_idx + max_num_grids, num_frames)

        for idx in range(start_idx, end_idx):
            img = frames[idx]

            img_resized = img.resize((vit_input_size, vit_input_size))

            # position within the grid, filled row by row
            local_idx = idx - start_idx
            x_offset = (local_idx % max_grid_shape[0]) * vit_input_size
            y_offset = (local_idx // max_grid_shape[0]) * vit_input_size

            combined_image.paste(img_resized, (x_offset, y_offset))

        image_list.append(combined_image)

    if leftover_frames > 0:
        # final, partially filled canvas for the remaining frames
        canvas_idx = num_canvases

        combined_image = Image.new(
            "RGB", (vit_input_size * max_grid_shape[0], vit_input_size * max_grid_shape[1]), color=(0, 0, 0)
        )

        for idx in range(leftover_frames):
            img = frames[num_canvases * max_num_grids + idx]

            img_resized = img.resize((vit_input_size, vit_input_size))

            x_offset = (idx % max_grid_shape[0]) * vit_input_size
            y_offset = (idx // max_grid_shape[0]) * vit_input_size

            combined_image.paste(img_resized, (x_offset, y_offset))

        image_list.append(combined_image)

    return image_list
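

# A minimal, self-contained sketch of the frame-to-grid helpers above (synthetic frames, not part
# of the released processor): seven dummy frames fit into a single 3x3 grid, and the matching time
# label covers 7 * 0.4 s of video.
if __name__ == "__main__":
    dummy_frames = np.zeros((7, 360, 640, 3), dtype=np.uint8)  # (num_frames, H, W, C)
    grids = combine_frames_into_images(dummy_frames, max_grid_shape=(3, 3), vit_input_size=378)
    labels = calc_timestamp_video_grids(dummy_frames, time_interval=0.4, max_grid_shape=(3, 3))
    print(len(grids), grids[0].size)  # 1 (1134, 1134)
    print(labels)  # ['0.00s~2.80s']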