diff --git "a/preprocessor.py" "b/preprocessor.py" new file mode 100644--- /dev/null +++ "b/preprocessor.py" @@ -0,0 +1,3285 @@ +import ast +import copy +import datetime +import gc +import io +import json +import math +import mimetypes +import os +import random +import re +import sys +import tarfile +import tempfile +import zipfile +from collections import defaultdict, deque +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import av +import cv2 +import numpy as np +import PIL +import pkg_resources +import scipy.signal as scsig +import torch +from decord import VideoReader, cpu +from PIL import Image, ImageDraw +from smart_open import open +from torchvision.transforms.functional import to_tensor + +from hcxvlm.dataset.base_dataset import image_decoder +from hcxvlm.dataset.hcx_vision_prompter import HCXVisionPrompter + +CHOICES = list(map(chr, range(97, 123))) +IGNORE_INDEX = -100 +DEFAULT_SAMPLE_RATE = 16000 +MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600 +DEFAULT_VOLUME_LEVEL = 10 ** (-26 / 20) + +hcx_vision_prompter = HCXVisionPrompter() + + +def hpf_normalize( + wav: np.ndarray, + sr: int = DEFAULT_SAMPLE_RATE, + volume_level: float = DEFAULT_VOLUME_LEVEL, +) -> np.ndarray: + assert (wav**2).mean() > 0, "Error in the wav file" + + filter_ = scsig.butter(2, 70, "highpass", fs=sr, output="sos") + wav = scsig.sosfilt(filter_, wav) + wav = wav.astype(np.float32) + + gain = min(volume_level / (wav**2).mean() ** 0.5, 1 / np.max(np.abs(wav))) + wav *= gain + return wav + + +def convert_bboxes(img, img_meta): + for k, v in img_meta.items(): + if k == "region": + bbox_key = "bbox" if "bbox" in img_meta[k] else "boundingBox" + img_meta[k] = reform_bbox( + img_meta[k][bbox_key], img.size, format=img_meta[k]["format"] + ) + return img_meta + + +def reform_bbox(bbox, image_size, format="REL_XYXY"): + w, h = image_size + if format == "REL_XYXY": + x1, y1, x2, y2 = bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h + elif format == "REL_XYWH": + x1, y1 = bbox[0] * w, bbox[1] * h + x2, y2 = x1 + bbox[2] * w, y1 + bbox[3] * h + else: + raise NotImplementedError + new_bbox = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] + return new_bbox + + +def generate_random_color(use_alpha=True, seed=None): + if seed is None: + seed = np.random.default_rng() + + if use_alpha: + color_list = [ + ("빨강", (255, 127, 127, 100)), + ("노랑", (255, 255, 127, 100)), + ("초록", (127, 255, 125, 100)), + ("하늘", (127, 255, 255, 100)), + ("파랑", (127, 127, 255, 100)), + ("보라", (255, 127, 255, 100)), + ] + else: + color_list = [ + ("빨강", (255, 0, 0)), + ("노랑", (255, 255, 0)), + ("초록", (0, 128, 0)), + ("하늘", (135, 206, 235)), + ("파랑", (0, 0, 255)), + ("보라", (128, 0, 128)), + ] + return color_list[seed.integers(0, len(color_list))] + + +EN_COLOR = { + "빨강": "red", + "노랑": "yellow", + "초록": "green", + "하늘": "sky blue", + "파랑": "blue", + "보라": "purple", +} + + +def overlay_rectangle(image, words, lang, seed=None): + color_str, color = generate_random_color(seed=seed) + draw = ImageDraw.Draw(image, "RGBA") + for word in words: + shape_rect = word["bbox"] + shape_rect = [(round(x[0]), round(x[1])) for x in shape_rect] + draw.polygon(shape_rect, color) + del draw + if lang == "en": + color_str = EN_COLOR[color_str] + return image, color_str + + +def convert_tags_for_video(img, json): + """video 데이터에는 태그 대신 tag가 있음. 
+ img 숫자 만큼 tag 대신 tag를 변환하여 넣음 + """ + image_tag = "".join([f"" for idx in range(len(img))]) + for json_key in json: + if "qa_pairs" in json_key: + new_qa_pairs = [] + for qa_pair in json[json_key]: + question = qa_pair[0] + question = question.replace("", image_tag) + new_qa_pairs.append([question, qa_pair[1]]) + json[json_key] = new_qa_pairs + + return img, json + + +def sampling_multiturn_single_img( + seq, + count, + multiturn_preserve_order=True, + multiturn_continuous=False, + is_train: bool = True, + seed=None, +): + if seed is None: + seed = np.random.default_rng() + n_sample = min(count, len(seq)) + + if multiturn_continuous: + if len(seq) <= n_sample: + start_index = 0 + else: + start_index = seed.integers(0, len(seq) - n_sample) + indices = range(start_index, start_index + n_sample) + elif multiturn_preserve_order: + indices = sorted(seed.choice(range(len(seq)), size=n_sample, replace=False)) + else: + indices = seed.choice(range(len(seq)), size=n_sample, replace=False) + + return [seq[i] for i in indices] + + +def draw_bbox(image, bbox, lang="en", line_width=5, seed=None): + if seed is None: + seed = np.random.default_rng() + color_str, color = generate_random_color(use_alpha=False, seed=seed) + draw = ImageDraw.Draw(image, "RGB") + rect_bbox = (bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]) + draw.rectangle(rect_bbox, outline=color, width=line_width) + del draw + if lang == "en": + color_str = EN_COLOR[color_str] + return image, color_str + + +def bbox_process(bbox, detection_precision=2): + bbox_str = "[" + for idx, point in enumerate(bbox): + if idx % 2 == 0: + normalized = point + else: + normalized = point + + if idx < len(bbox) - 1: + bbox_str += format(normalized, ".2f") + ", " + else: + bbox_str += format(normalized, ".2f") + bbox_str += "]" + return bbox_str + + +def load_txt(file_path): + lines_list = [] + with open(file_path, "r") as file: + for line in file: + lines_list.append(line.replace("\\n", "\n").strip()) + return lines_list + + +def convert_format_for_multi_image( + img, json, convert_key_list=["words", "text", "objects", "entities"] +): + """single image dataset 과 multi image dataset 에서 읽어온 img, json format 이 다름. 
+ 따라서 single image dataset 에서 읽어온 img, json 을 multi image dataset 의 format (dict) 으로 convert + """ + is_multi_image_dataset = isinstance(img, dict) + if not is_multi_image_dataset: + img = {"00": img} + + for convert_key in convert_key_list: + if convert_key in json: + json[convert_key] = {"00": json[convert_key]} + + for json_key in json: + if "region" in json_key: + json[json_key] = {"00": json[json_key]} + else: + for convert_key in convert_key_list: + if convert_key in json: + if isinstance(json[convert_key], list): + json[convert_key] = {"00": json[convert_key]} + + for json_key in json: + if "region" in json_key: + if isinstance(json[json_key], list): + json[json_key] = {"00": json[json_key]} + + return is_multi_image_dataset, img, json + + +class ConditionalError(Exception): + def __init__(self, message="Our assertion error"): + super().__init__(message) + + +def get_wds_default_config(default_config, existing_default_config=None): + if existing_default_config is None: + default_config_check_dict = { + "subtask": "", + "reasoning": False, + "use_task_prompt": True, + "get_random": True, + "add_instruct_prompts": [], + "multiturn_n_samples": 0, + "multiturn_preserve_order": True, + "multiturn_continuous": False, + "insert_ocr": 200, + "ocr_filter_strategy": "confidence", + "ocr_use_ratio": 1.0, + "entity_top_k": 0, + "entity_keyword_threshold": 100, + "entity_keyword_fashion_threshold": 100, + "entity_use_ratio": 0.0, + "llava_pretrain": False, + "random_system_prob": 0.0, + "random_system_path": "", + "random_tool_prob": 0.005, + } + else: + default_config_check_dict = existing_default_config + if default_config is None: + default_config = default_config_check_dict + else: + for key, value in default_config_check_dict.items(): + if key not in default_config: + default_config[key] = value + return default_config + + +def get_datalake_default_config(default_config): + default_config_check_dict = { + "multiturn_n_samples": 0, + "multiturn_preserve_order": True, + "multiturn_continuous": True, + "insert_ocr": 0, + "ocr_filter_strategy": "confidence", + "entity_top_k": 0, + "entity_keyword_threshold": 0, + "entity_keyword_fashion_threshold": 0, + "entity_use_ratio": 0.0, + "ocr_use_ratio": 0.0, + "llava_pretrain": False, + "random_system_prob": 0.0, + "random_system_path": "", + "random_tool_prob": 0.005, + } + if default_config is None: + default_config = default_config_check_dict + else: + for key, value in default_config_check_dict.items(): + if key not in default_config: + default_config[key] = value + return default_config + + +@dataclass +class Processed_sample: + input_str: str = None + input_ids: list = None + label_ids: list = None + imgs: list = None + discrete_imgs: list = None + videos: list = None + videos_duration: List[dict] = None + video_audios: list = None + audios: list = None + audios_duration: List[dict] = None + discrete_audios: list = None + sample_mm_counter: dict = None + + +from hcxvlm.dataset.bbox_processor import ( + extract_bboxes, + insert_bboxes_to_json, + is_bbox_padded, +) + + +class Preprocessor: + prompt_head = "" + va_prefix = "\n<|im_start|>" + new_line = "\n" + turn_prefix = "<|im_start|>" + turn_suffix = "<|im_end|>" + mime_start = "<|mime_start|>" + mime_end = "<|mime_end|>" + aux_img_start = "<|image_aux_start|>" + aux_img_end = "<|image_aux_end|>" + aux_video_start = "<|video_aux_start|>" + aux_video_end = "<|video_aux_end|>" + aux_audio_start = "<|audio_aux_start|>" + aux_audio_end = "<|audio_aux_end|>" + image_start = "<|image_start|>" + 
image_end = "<|image_end|>" + image_pad = "<|IMAGE_PAD|>" + video_start = "<|video_start|>" + video_end = "<|video_end|>" + video_pad = "<|VIDEO_PAD|>" + audio_start = "<|audio_start|>" + audio_end = "<|audio_end|>" + audio_pad = "<|AUDIO_PAD|>" + discrete_image_start = "<|discrete_image_start|>" + discrete_image_end = "<|discrete_image_end|>" + discrete_image_pad = "<|DISCRETE_IMAGE_PAD|>" + video_audio_pad = "<|VIDEO_AUDIO_PAD|>" + discrete_audio_start = "<|discrete_audio_start|>" + discrete_audio_end = "<|discrete_audio_end|>" + discrete_audio_pad = "<|DISCRETE_AUDIO_PAD|>" + + discrete_image_eol = "<|vision_eol|>" + discrete_image_eof = "<|vision_eof|>" + discrete_image_ratios = { + (1, 1): "<|vision_ratio_1:1|>", + (1, 2): "<|vision_ratio_1:2|>", + (2, 1): "<|vision_ratio_2:1|>", + (3, 4): "<|vision_ratio_3:4|>", + (4, 3): "<|vision_ratio_4:3|>", + (3, 5): "<|vision_ratio_3:5|>", + (5, 3): "<|vision_ratio_5:3|>", + (4, 5): "<|vision_ratio_4:5|>", + (5, 4): "<|vision_ratio_5:4|>", + (6, 9): "<|vision_ratio_6:9|>", + (9, 6): "<|vision_ratio_9:6|>", + (9, 16): "<|vision_ratio_9:16|>", + (16, 9): "<|vision_ratio_16:9|>", + } + + aux_vid_prompt = ( + "다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. " + ) + aux_audio_prompt = ( + "다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. " + ) + + def __init__( + self, + tokenizer=None, + prepare_input_fn=None, + prepare_audio_input_fn=None, + sample_min_length=0, + decoder_max_length=None, + mode="train", + model=None, + datalake_default_config=None, + wds_default_config=None, + video_config=None, + train_video=False, + train_audio=False, + sequence_parallel_size=1, + video_audio_compressor_type=None, + ): + self.sequence_parallel_size = sequence_parallel_size + if sequence_parallel_size > 1: + self.rng = np.random.default_rng(seed=42) + else: + self.rng = np.random.default_rng() + + if model is not None: + tokenizer = model.tokenizer + decoder_max_length = 16000 + + if model is not None and prepare_input_fn is None: + raise "please give ImageProcessor!" 
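+        # Descriptive note on the block that follows: `prepare_input_fn` is verified to be a
+        # Qwen2.5-VL-style processor — first by trying transformers' Qwen2_5_VLProcessor,
+        # then by falling back to the in-house HCXVisionV2Processor — and the subsequent
+        # assert fails if neither isinstance check succeeds.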
+ + self.prepare_input_fn = prepare_input_fn + self.prepare_audio_input_fn = prepare_audio_input_fn + try: + from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import ( + Qwen2_5_VLProcessor, + ) + + self.is_qwen_visual = isinstance(prepare_input_fn, Qwen2_5_VLProcessor) + except Exception as e: + self.is_qwen_visual = False + try: + if not self.is_qwen_visual: + from hcxvlm.models.processing_vlm import HCXVisionV2Processor + + self.is_qwen_visual = isinstance(prepare_input_fn, HCXVisionV2Processor) + except Exception as e: + self.is_qwen_visual = False + assert self.is_qwen_visual, "qwen2.5-vl visual prepare_input_fn import error" + + self.video_max_num_frames = ( + video_config["video_max_num_frames"] + if video_config and "video_max_num_frames" in video_config + else 120 + ) + self.video_max_pixels = ( + video_config["video_max_pixels"] + if video_config and "video_max_pixels" in video_config + else 378 * 378 + ) + + self.tokenizer = tokenizer + self.sample_min_length = sample_min_length + self.decoder_max_length = decoder_max_length + self.mode = mode + self.default_config = get_datalake_default_config(datalake_default_config) + self.wds_default_config = get_wds_default_config(wds_default_config) + self.train_video = train_video + self.train_audio = train_audio + self.video_audio_compressor_type = video_audio_compressor_type + + self.img_token = self.tokenizer.encode(Preprocessor.image_pad)[0] + assert ( + len(self.tokenizer.encode(Preprocessor.image_pad)) == 1 + ), "img_token is not configured in tokenizer" + + self.discrete_image_token = self.tokenizer.encode( + Preprocessor.discrete_image_pad + )[0] + assert ( + len(self.tokenizer.encode(Preprocessor.discrete_image_pad)) == 1 + ), "discrete_image_token is not configured in tokenizer" + + self.discrete_image_eol_token = self.tokenizer.encode( + Preprocessor.discrete_image_eol + )[0] + assert ( + len(self.tokenizer.encode(Preprocessor.discrete_image_eol)) == 1 + ), "discrete_image_eol_token is not configured in tokenizer" + + self.discrete_image_eof_token = self.tokenizer.encode( + Preprocessor.discrete_image_eof + )[0] + assert ( + len(self.tokenizer.encode(Preprocessor.discrete_image_eof)) == 1 + ), "discrete_image_eof_token is not configured in tokenizer" + + self.discrete_image_ratio_tokens = dict() + for ratio, token_str in Preprocessor.discrete_image_ratios.items(): + token_id = self.tokenizer.encode(token_str)[0] + assert ( + len(self.tokenizer.encode(token_str)) == 1 + ), f"discrete_image_ratio_token {token_str} is not configured in tokenizer" + self.discrete_image_ratio_tokens[ratio] = token_id + + self.video_token = self.tokenizer.encode(Preprocessor.video_pad)[0] + assert ( + len(self.tokenizer.encode(Preprocessor.video_pad)) == 1 + ), "video_token is not configured in tokenizer" + + self.video_audio_token = self.tokenizer.encode(Preprocessor.video_audio_pad)[0] + assert ( + len(self.tokenizer.encode(Preprocessor.video_audio_pad)) == 1 + ), "video_audio_token is not configured in tokenizer" + + def resize_min_edge(img: Image.Image) -> Image.Image: + w, h = img.size + min_size = 28 + if min(w, h) >= min_size: + return img + if w < h: + new_w = min_size + new_h = int(h * (min_size / w)) + else: + new_h = min_size + new_w = int(w * (min_size / h)) + return img.resize((new_w, new_h), Image.BICUBIC) + + self._resize_min_edge = resize_min_edge + + self.audio_token = self.tokenizer.encode(Preprocessor.audio_pad)[0] + assert ( + len(self.tokenizer.encode(Preprocessor.audio_pad)) == 1 + ), "audio_token is not configured in 
tokenizer" + + self.discrete_audio_token = self.tokenizer.encode( + Preprocessor.discrete_audio_pad + )[0] + assert ( + len(self.tokenizer.encode(Preprocessor.discrete_audio_pad)) == 1 + ), "audio_token is not configured in tokenizer" + + from hcxvlm.dataset.json_processer import generate_prompt + + self.generate_prompt = generate_prompt + + self.mimes = list() + for mime_filename in [ + "words_alpha.txt", + "korean-366506-wordslistUnique.txt", + ]: + self.mimes += ( + pkg_resources.resource_string( + "hcxvlm", f"dataset/hcx_vision_prompter/prompts/{mime_filename}" + ) + .decode("utf-8") + .split("\r\n") + ) + + self.common_tools = [] + try: + common_tools_bytes = pkg_resources.resource_string( + "hcxvlm", + "dataset/hcx_vision_prompter/prompts/common_tools.jsonl", + ) + for line in common_tools_bytes.decode("utf-8").splitlines(): + line = line.strip() + if not line: + continue + try: + self.common_tools.append(json.loads(line)) + except Exception: + continue + except Exception: + self.common_tools = [] + + self.random_system_prompt = "" + if self.default_config["random_system_path"] != "": + self.random_system_prompt = "" + with open(self.default_config["random_system_path"], "r") as f: + for line in f: + self.random_system_prompt += line + + if ( + self.random_system_prompt != "" + and self.wds_default_config["random_system_path"] != "" + ): + assert ( + self.wds_default_config["random_system_path"] + == self.default_config["random_system_path"] + ), "random_system_path in both default_config and wds_default_config should be the same" + + def _find_best_ratio_token(self, original_size): + """Find the best ratio token based on original_size""" + base_ratios = list(self.discrete_image_ratio_tokens.keys()) + vision_aspect_ratios = [ + r for ratio in base_ratios for r in [ratio, ratio[::-1]] + ][1:] + + if not isinstance(original_size, list) or len(original_size) != 2: + return self.discrete_image_ratio_tokens[(1, 1)] + + h, w = original_size + if h == 0 or w == 0: + return self.discrete_image_ratio_tokens[(1, 1)] + + ratios = [i / j for i, j in vision_aspect_ratios] + + best_size_idx = np.argmin([abs(w / h - r) for r in ratios]) + + i, j = vision_aspect_ratios[best_size_idx] + return self.discrete_image_ratio_tokens[(i, j)] + + @classmethod + def prompt_mime( + cls, + mimes: Optional[list[str]] = None, + file_name: str = None, + tag_idx: int = 1, + fixed_mime: bool = False, + is_video: bool = False, + is_audio: bool = False, + seed: np.random.Generator = None, + ) -> list[dict]: + assert mimes or file_name + + if seed is None: + seed = np.random.default_rng() + + if file_name: + name, ext = os.path.splitext(file_name) + ext = ext.lstrip(".") + elif fixed_mime: + ext = "jpeg" + name = mimes[tag_idx] + elif not fixed_mime and seed is not None: + ext = seed.choice(["png", "jpeg"]) + name = mimes[seed.integers(0, len(mimes))] + else: + ext = "jpeg" + name = mimes[tag_idx] + + if is_video: + ext_candidates = ["mp4", "mov", "avi", "webm"] + if fixed_mime: + ext = "mp4" + elif ext not in ext_candidates: + ext = seed.choice(ext_candidates) + + filename = f"{name}.{ext}" + mime_type = mimetypes.guess_type(filename)[0] + mime_prompt = { + "id": f"video_{str(tag_idx).zfill(2)}", + "type": f"{mime_type}", + "filename": f"{filename}", + } + return mime_prompt + + if is_audio: + ext_candidates = ["mp3", "wav", "aac", "flac", "pcm"] + if fixed_mime: + ext = "wav" + elif ext not in ext_candidates: + ext = seed.choice(ext_candidates) + + filename = f"{name}.{ext}" + mime_type = 
mimetypes.guess_type(filename)[0] + mime_prompt = { + "id": f"audio_{str(tag_idx).zfill(2)}", + "type": f"{mime_type}", + "filename": f"{filename}", + } + return mime_prompt + + if file_name: + filename = f"{name}.{ext}" + mime_type = mimetypes.guess_type(filename)[0] + mime_prompt = { + "id": f"image_{str(tag_idx).zfill(2)}", + "type": f"{mime_type}", + "filename": f"{filename}", + } + else: + mime_prompt = { + "id": f"image_{str(tag_idx).zfill(2)}", + "type": f"image/{ext}", + "filename": f"{name}.{'jpg' if ext == 'jpeg' else 'png'}", + } + return mime_prompt + + @classmethod + def ocr_preprocess( + cls, + words: list[dict], + n_insert_ocr_tokens: int = 2000, + insert_ocr: int = 200, + ocr_use_ratio: float = 0.5, + tokenizer=None, + seed=None, + ) -> list[str]: + if seed is None: + seed = np.random.default_rng() + if ocr_use_ratio < seed.random(): + return None + if insert_ocr == 0: + return None + + confidence_list = [] + insert_ocr_prompt = [] + for word in words: + if "confidence" in word: + confidence_list.append(word["confidence"]) + has_ocr_confidence = len(confidence_list) >= insert_ocr + + if len(words) <= insert_ocr or not has_ocr_confidence: + insert_ocr_prompt += [ + d["text"].strip() for d in words if d["text"].strip() + ][:insert_ocr] + else: + confidence_threshold = 0.3 + cnt = 0 + for word in words: + if word["text"] == "": + continue + if word["confidence"] >= confidence_threshold: + insert_ocr_prompt.append(word["text"]) + cnt += 1 + if cnt >= insert_ocr: + break + ocr_inputs = " ".join(insert_ocr_prompt) + if tokenizer: + ocr_inputs = tokenizer.decode( + tokenizer.encode(ocr_inputs)[:n_insert_ocr_tokens] + ) + return ocr_inputs + + @classmethod + def lens_preprocess( + cls, + lens: list[dict], + entity_top_k: int = 100, + entity_keyword_threshold: float = 0.0, + entity_keyword_fashion_threshold: float = 0.0, + entity_use_ratio: float = 0.0, + seed=None, + ): + if seed is None: + seed = np.random.default_rng() + if seed.uniform(0, 1) > entity_use_ratio: + return None + + entities = lens + filter_idx = [] + insert_entity_prompt = {} + for idx, entity in enumerate(entities): + if entity["type"] != "naver_lens_api": + filter_idx.append(idx) + continue + if ( + isinstance(entity_keyword_threshold, (int, float)) + and entity["confidence"] < entity_keyword_threshold + ): + filter_idx.append(idx) + continue + if ( + isinstance(entity_keyword_fashion_threshold, (int, float)) + and ("fashion" in entity["info"]["classes"]) + and entity["confidence"] < entity_keyword_fashion_threshold + ): + filter_idx.append(idx) + continue + + entityvalue = [ + keyword for idx, keyword in enumerate(entities) if idx not in filter_idx + ] + entityvalue = sorted(entityvalue, key=lambda x: x["confidence"], reverse=True) + + important_entity_list = [] + local_entity_str_list = [] + keywords_and_bbox_per_detector = {} + for keyword_dict in entityvalue[:entity_top_k]: + object_class = "/".join(keyword_dict["info"]["classes"]) + if object_class not in keywords_and_bbox_per_detector.keys(): + keywords_and_bbox_per_detector[object_class] = [] + keywords_and_bbox_per_detector[object_class].append(keyword_dict) + + for object_class in keywords_and_bbox_per_detector.keys(): + entities_per_object = keywords_and_bbox_per_detector[object_class] + normalized_bbox = bbox_process( + [*entities_per_object[0]["bbox"][0], *entities_per_object[0]["bbox"][2]] + ) + entities = [entity["text"] for entity in entities_per_object] + if "context" in object_class: + important_entity_list += entities + + else: + 
local_entity_str_list += [ + str(normalized_bbox) + " " + ", ".join(entities) + ] + if len(important_entity_list) > 0: + insert_entity_prompt["lens_keywords"] = ", ".join(important_entity_list) + if len(local_entity_str_list) > 0: + insert_entity_prompt["lens_local_keywords"] = " ".join( + local_entity_str_list + ) + + return insert_entity_prompt + + @classmethod + def prompt_toollist( + cls, + output, + tokenizer=None, + turn: Optional[dict] = None, + content: Optional[list[dict]] = None, + ): + assert content or turn + if turn is None: + turn = { + "role": "tool_list", + "content": content, + } + + toollist_str = ( + cls.turn_prefix.strip() + + turn["role"] + + "\n" + + turn["content"] + + cls.turn_suffix + ) + + if hasattr(output, "input_str"): + output.input_str += toollist_str + + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(toollist_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + return output + + @classmethod + def prompt_system( + cls, + output, + tokenizer=None, + turn: Optional[dict] = None, + content: Optional[str] = None, + seed=None, + tool_prompt=None, + system_role_count=0, + ): + assert content or turn + if seed is None: + seed = np.random.default_rng() + if turn is None: + system_prompt = content + else: + if "candidates" in turn: + if len(turn["candidates"]) > 0: + system_prompt = seed.choice(turn["candidates"]) + if type(system_prompt) is dict: + system_prompt = system_prompt["content"] + else: + system_prompt = "" + elif isinstance(turn["content"], str): + system_prompt = turn["content"] + elif len(turn["content"]) > 0: + system_prompt = seed.choice(turn["content"]) + + system_str = cls.turn_prefix + turn["role"] + "\n" + system_str += system_prompt.strip() + if system_role_count == 0: + if system_prompt.strip(): + system_str += "\n" + system_str += tool_prompt + system_str += cls.turn_suffix + + if hasattr(output, "input_str"): + output.input_str += system_str + + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(system_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + return output + + @classmethod + def load_mm( + cls, + output, + img_dir: str = "", + turn: Optional[dict] = None, + image_urls: Optional[list[str]] = None, + image_metas: Optional[list[dict]] = None, + video_urls: Optional[list[str]] = None, + video_metas: Optional[list[dict]] = None, + audio_urls: Optional[list[str]] = None, + audio_metas: Optional[list[dict]] = None, + prepare_input_fn=None, + prepare_audio_input_fn=None, + max_image_cnt=21, + video_max_num_frames=None, + video_max_pixels=None, + use_audio: bool = False, + audio_sample_rate: int = 16000, + ): + assert (image_urls or video_urls or audio_urls) or turn + if turn is None: + turn = {} + if image_urls: + turn.update({"image_urls": image_urls}) + turn.update({"image_metas": image_metas}) + if video_urls: + turn.update({"video_urls": video_urls}) + turn.update({"video_metas": video_metas}) + if audio_urls: + turn.update({"audio_urls": audio_urls}) + turn.update({"audio_metas": audio_metas}) + + if "video_urls" in turn: + if len(turn["video_urls"]) and (prepare_input_fn is None): + raise ConditionalError("video processing needs 'prepare_input_fn'") + + if not isinstance(turn["content"], str): + raise ConditionalError(f"turn['content'] must be a string") + + turn["content"] = re.sub(r"", "<|image|>", turn["content"]) + pattern 
= re.compile( + r"<\|video\|>|<\|image\|>|<\|t2i_model_generation_target_discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>" + ) + tags = [match.group() for match in pattern.finditer(turn["content"])] + + img_idx = 0 + vid_idx = 0 + aud_idx = 0 + + if "image_urls" not in turn: + turn["image_urls"] = [] + if "video_urls" not in turn: + turn["video_urls"] = [] + if "audio_urls" not in turn: + turn["audio_urls"] = [] + + for tag in tags: + if ( + tag == "<|image|>" + or tag == "<|t2i_model_generation_target_discrete_image|>" + ): + img_path = turn["image_urls"][img_idx] + + if isinstance(img_path, str): + if "#" in img_path: + compression_path, img_path = img_path.split("#", 1) + compression_path = os.path.join(img_dir, compression_path) + assert compression_path[-4:] in [ + ".zip", + ".tar", + ], f"unsupported compression format: {compression_path}" + + with open(compression_path, "rb") as comp_file: + if compression_path.endswith(".zip"): + with zipfile.ZipFile(comp_file, "r") as zip_file: + with zip_file.open(img_path) as img_file: + img_binary = img_file.read() + elif compression_path.endswith(".tar"): + with tarfile.open( + fileobj=comp_file, mode="r" + ) as tar_file: + img_file = tar_file.extractfile(img_path) + img_binary = img_file.read() + else: + with open(os.path.join(img_dir, img_path), "rb") as f: + img_binary = f.read() + img = image_decoder(img_binary) + else: + if isinstance(img_path, (bytes, bytearray)): + img = io.BytesIO(img_path) + img = Image.open(img).convert("RGB") + else: + img = img_path + if not isinstance(img, Image.Image): + img = Image.fromarray(np.uint8(img)).convert("RGB") + + if "image_metas" in turn and turn["image_metas"]: + turn["image_metas"][img_idx] = convert_bboxes( + img, turn["image_metas"][img_idx] + ) + + if tag == "<|image|>": + output.imgs.append(img) + output.discrete_imgs.append(img) + + img_idx += 1 + elif tag == "<|video|>": + video_path = turn["video_urls"][vid_idx] + if isinstance(video_path, str): + if "#" in video_path: + compression_path, video_path = video_path.split("#", 1) + compression_path = os.path.join(img_dir, compression_path) + assert compression_path[-4:] in [ + ".zip", + ".tar", + ], f"unsupported compression format: {compression_path}" + + with open(compression_path, "rb") as comp_file: + if compression_path.endswith(".zip"): + with zipfile.ZipFile(comp_file, "r") as zip_file: + video_file = zip_file.open(video_path) + video_binary = video_file.read() + elif compression_path.endswith(".tar"): + with tarfile.open( + fileobj=comp_file, mode="r" + ) as tar_file: + video_file = tar_file.extractfile(video_path) + video_binary = video_file.read() + else: + with open(os.path.join(img_dir, video_path), "rb") as f: + video_binary = f.read() + video_binary = io.BytesIO(video_binary) + else: + video_binary = video_path + + assert isinstance(video_binary, io.BytesIO), "video binary read error" + + try: + from hcxvlm.dataset.qwen_vision_process import process_vision_info + except: + from qwen_vl_utils import process_vision_info + + if video_max_num_frames is None: + video_max_num_frames = 120 + if video_max_pixels is None: + video_max_pixels = 378 * 378 + + messages = [ + [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": video_binary, + "max_frames": video_max_num_frames, + "max_pixels": video_max_pixels, + } + ], + } + ], + ] + _, videos, video_kwargs = process_vision_info( + messages, + return_video_kwargs=True, + use_audio=use_audio, + audio_sample_rate=audio_sample_rate, + ) + output.videos.append(videos[0]) + 
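+                    # videos[0] is a frames-first tensor (shape[0] == number of sampled frames),
+                    # so dividing the frame count by the sampling fps reported in video_kwargs
+                    # gives the clip length in seconds, recorded as "video_duration" just below.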
video_len = round(videos[0].shape[0] / video_kwargs["fps"][0], 2) + output.videos_duration.append( + { + "video_duration": f"{video_len}s", + } + ) + + if use_audio and "audio_chunks" in video_kwargs: + audio_chunks = video_kwargs["audio_chunks"][0] + if audio_chunks is not None: + output.video_audios.append(audio_chunks) + else: + output.video_audios.append([]) + elif use_audio: + output.video_audios.append([]) + + vid_idx += 1 + + elif tag == "<|audio|>" or tag == "<|discrete_audio|>": + audio_path = turn["audio_urls"][aud_idx] + if isinstance(audio_path, str): + if "#" in audio_path: + compression_path, inner_path = audio_path.split("#", 1) + compression_path = os.path.join(img_dir, compression_path) + assert compression_path[-4:] in [ + ".zip", + ".tar", + ], f"unsupported compression format: {compression_path}" + with open(compression_path, "rb") as comp_file: + if compression_path.endswith(".zip"): + with zipfile.ZipFile(comp_file, "r") as zip_file: + with zip_file.open(inner_path) as audio_file: + audio_binary = audio_file.read() + elif compression_path.endswith(".tar"): + with tarfile.open( + fileobj=comp_file, mode="r" + ) as tar_file: + audio_file = tar_file.extractfile(inner_path) + audio_binary = audio_file.read() + else: + with open(os.path.join(img_dir, audio_path), "rb") as f: + audio_binary = f.read() + audio_stream = io.BytesIO(audio_binary) + else: + if isinstance(audio_path, (bytes, bytearray)): + audio_stream = io.BytesIO(audio_path) + else: + audio_stream = audio_path + + try: + import librosa + + y, sr = librosa.load( + audio_stream, sr=DEFAULT_SAMPLE_RATE, mono=True + ) + assert ( + DEFAULT_SAMPLE_RATE == sr + ), f"librosa resampling failed: {DEFAULT_SAMPLE_RATE} != {sr}" + except Exception as e: + raise ConditionalError( + f"audio decoding failed for {audio_path}: {e}" + ) + + audio_duration = len(y) / sr + if audio_duration < 0.5: + raise ConditionalError( + f"Audio too short ({audio_duration:.2f}s). Minimum 0.5s required." + ) + if audio_duration > 600: + raise ConditionalError( + f"Audio duration ({audio_duration:.2f}s) exceeds maximum allowed duration (600s)" + ) + + if len(y) < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES: + raise ConditionalError( + f"Audio too short ({len(y)} samples = {audio_duration:.4f}s < 0.1s). " + f"Minimum {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES} samples required for CosyVoice encoder." 
+ ) + + if not hasattr(output, "audios"): + output.audios = [] + if not hasattr(output, "discrete_audios"): + output.discrete_audios = [] + + normalized_y = hpf_normalize(y) + normalized_y = torch.from_numpy(normalized_y).float() + + output.discrete_audios.append(normalized_y) + if tag == "<|audio|>": + + output.audios.append(y) + total_duration = len(y) / sr + output.audios_duration.append( + { + "duration": f"{(total_duration):.2f}s", + } + ) + + aud_idx += 1 + else: + raise ConditionalError( + f"{tag} is not in ['<|image|>', '<|video|>', '<|audio|>']" + ) + + return output + + @classmethod + def prompt_user( + cls, + output, + tokenizer=None, + turn: Optional[dict] = None, + content: Optional[str] = None, + is_train=False, + fixed_mime=False, + insert_ocr=300, + file_names: Optional[list[str]] = None, + mimes: Optional[list[str]] = None, + mm_tokens: Optional[list[str]] = None, + words: Optional[list] = None, + lens: Optional[list] = None, + query_template: Optional[list[str]] = None, + config: Optional[dict] = None, + seed: np.random.Generator = None, + ): + assert content or turn + if turn is None: + image_metas = [ + {"words": words[i], "lens": lens[i]} for i in range(len(words)) + ] + turn = { + "content": content, + "image_metas": image_metas, + } + if seed is None: + seed = np.random.default_rng() + + turn["content"] = re.sub(r"", "<|image|>", turn["content"]) + turn["content"] = re.sub(r"", "<|video|>", turn["content"]) + turn["content"] = re.sub(r"", "<|audio|>", turn["content"]) + + pattern = re.compile(r"(<\|video\|>|<\|image\|>|<\|audio\|>)") + + all_tags_in_order = [ + match.group() for match in pattern.finditer(turn["content"]) + ] + n_vids = sum(1 for tag in all_tags_in_order if tag == "<|video|>") + n_audios = sum(1 for tag in all_tags_in_order if tag == "<|audio|>") + + assert ( + len(turn.get("image_urls", [])) + + len(turn.get("video_urls", [])) + + len(turn.get("audio_urls", [])) + ) == len( + all_tags_in_order + ), f"Number of media URLs does not match number of media tags." 
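+        # Default pad tokens when the caller passes none: one <|AUDIO_PAD|> per <|audio|> tag
+        # and one <|IMAGE_PAD|> for every other media tag; the llava_pretrain branch below
+        # then emits only this bare pad string (no mime/aux wrappers) for the user turn.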
+ + if mm_tokens is None: + mm_tokens = [ + cls.audio_pad if tag == "<|audio|>" else cls.image_pad + for tag in all_tags_in_order + ] + + assert len(mm_tokens) == len(all_tags_in_order) + + if config.get("llava_pretrain", False): + mm_str = "".join([mm_tokens[i] for i in range(len(all_tags_in_order))]) + if hasattr(output, "input_str"): + output.input_str += mm_str + + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + return output + + if query_template: + processed_content = seed.choice(query_template).format(turn["content"]) + + tags_after_template = pattern.findall(processed_content) + if len(all_tags_in_order) != len(tags_after_template): + cleaned_template_text = pattern.sub("", processed_content) + processed_content = "".join(all_tags_in_order) + cleaned_template_text + turn["content"] = processed_content + + content_parts = pattern.split(turn["content"].strip()) + + if hasattr(output, "input_str"): + output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" + if getattr(output, "input_ids", None) is not None: + role_encoded = tokenizer.encode( + f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False + ) + output.input_ids += role_encoded + if turn.get("trainable_role", False): + output.label_ids += role_encoded + else: + output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] + + tag_cursor = 0 + + for part in content_parts: + part = part.strip() + + if not part: + continue + + if part not in ["<|image|>", "<|video|>", "<|audio|>"]: + content_text = part + + if hasattr(output, "input_str"): + output.input_str += "\n" + content_text + if getattr(output, "input_ids", None) is not None: + content_encoded = tokenizer.encode( + "\n" + content_text, truncation=False + ) + output.input_ids += content_encoded + if turn.get("trainable_content", False): + output.label_ids += content_encoded + else: + output.label_ids += [ + IGNORE_INDEX for _ in range(len(content_encoded)) + ] + continue + + if part == "<|image|>": + mime = Preprocessor.prompt_mime( + mimes=mimes if not file_names else None, + fixed_mime=fixed_mime if not file_names else False, + file_name=file_names[tag_cursor] if file_names else None, + tag_idx=output.sample_mm_counter["image"], + is_video=False, + is_audio=False, + seed=seed, + ) + mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" + discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" + vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" + mm_str = ( + cls.new_line + + mime_str + + cls.new_line + + discrete_image_str + + cls.new_line + + vector_str + ) + + if hasattr(output, "input_str"): + output.input_str += mm_str + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + + output.sample_mm_counter["image"] += 1 + tag_cursor += 1 + + elif part == "<|video|>": + mime = Preprocessor.prompt_mime( + mimes=mimes if not file_names else None, + fixed_mime=fixed_mime if not file_names else False, + file_name=file_names[tag_cursor] if file_names else None, + tag_idx=output.sample_mm_counter["video"], + is_video=True, + is_audio=False, + seed=seed, + ) + mm_str = "" + aux_inputs = { + "video_duration": output.videos_duration[ + 
output.sample_mm_counter["video"] + ]["video_duration"], + } + mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" + aux_str = f"{cls.aux_video_start}{cls.aux_vid_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_video_end}" + vector_str = f"{cls.video_start}{cls.video_pad}{cls.video_end}" + mm_str += ( + cls.new_line + + mime_str + + cls.new_line + + aux_str + + cls.new_line + + vector_str + ) + if hasattr(output, "input_str"): + output.input_str += mm_str + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + output.sample_mm_counter["video"] += 1 + tag_cursor += 1 + + elif part == "<|audio|>": + mime = Preprocessor.prompt_mime( + mimes=mimes if not file_names else None, + fixed_mime=fixed_mime if not file_names else False, + file_name=file_names[tag_cursor] if file_names else None, + tag_idx=output.sample_mm_counter["audio"], + is_video=False, + is_audio=True, + seed=seed, + ) + mm_str = "" + aux_inputs = { + "audio_duration": output.audios_duration[ + output.sample_mm_counter["audio"] + ]["duration"], + } + mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" + aux_str = f"{cls.aux_audio_start}{cls.aux_audio_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_audio_end}" + discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" + vector_str = f"{cls.audio_start}{cls.audio_pad}{cls.audio_end}" + mm_str += ( + cls.new_line + + mime_str + + cls.new_line + + aux_str + + cls.new_line + + discrete_audio_str + + cls.new_line + + vector_str + ) + if hasattr(output, "input_str"): + output.input_str += mm_str + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + + output.sample_mm_counter["audio"] += 1 + tag_cursor += 1 + + if hasattr(output, "input_str"): + output.input_str += cls.turn_suffix + + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) + output.input_ids += token_ids + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + + return output + + @classmethod + def prompt_assistant( + cls, + output, + tokenizer=None, + turn: Optional[dict] = None, + role: Optional[str] = "assistant", + content: Optional[str] = None, + is_last_turn=False, + is_eval=True, + is_llava_pretrain=False, + is_after_last_user_turn=False, + ): + assert content or turn + if turn is None: + turn = { + "content": content, + "role": role, + } + + if is_llava_pretrain: + if hasattr(output, "input_str"): + output.input_str += turn["content"] + if getattr(output, "input_ids", None) is not None: + content_encoded = tokenizer.encode(turn["content"], truncation=False) + output.input_ids += content_encoded + output.label_ids += content_encoded + return output + + reasoning_content = turn.get("reasoning_content", "") + if ( + not reasoning_content + and isinstance(turn["content"], str) + and "" in turn["content"] + ): + parts = turn["content"].split("", 1) + reasoning_content = parts[0].split("", 1)[-1].lstrip("\n") + turn["content"] = parts[1].lstrip("\n") + + if is_after_last_user_turn and (is_last_turn or reasoning_content): + content_to_strip = turn.get("content") or "" + stripped_content = 
content_to_strip.lstrip("\n") + + if reasoning_content is None: + reasoning_content = "" + turn["content"] = ( + f"\n{reasoning_content.strip()}\n\n\n{stripped_content}" + ) + + if turn.get("tool_calls"): + for tool_call in turn["tool_calls"]: + func_name = tool_call.get("function", {}).get("name", "") + args = tool_call.get("function", {}).get("arguments", {}) + + if isinstance(args, str): + try: + args = json.loads(args) + except Exception: + pass + if not isinstance(args, dict): + print( + f"[error] tool_call.function.arguments가 dict이 아님: type={type(args)}, value={str(args)}" + ) + assert ( + False + ), "tool_call.function.arguments는 dict이거나 dict를 나타내는 JSON 문자열이어야 합니다." + + tool_turn_content = f"\n{func_name}\n" + + for key, value in args.items(): + arg_value = ( + json.dumps(value, ensure_ascii=False) + if not isinstance(value, str) + else value + ) + tool_turn_content += f"{key}\n{arg_value}\n" + tool_turn_content += "" + + if func_name == "t2i_model_generation": + assert ( + "<|t2i_model_generation_target_discrete_image|>" + in turn["content"] + ), "t2i_model_generation tool call must have target discrete image tag in content." + turn["content"] = turn["content"].replace( + "<|t2i_model_generation_target_discrete_image|>", + tool_turn_content, + ) + else: + turn["content"] += tool_turn_content + + pattern = re.compile( + r"(<\|image\|>|<\|discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>)" + ) + all_tags_in_order = [ + match.group() for match in pattern.finditer(turn["content"]) + ] + + assert ( + len(turn.get("image_urls", [])) + + len(turn.get("video_urls", [])) + + len(turn.get("audio_urls", [])) + ) == len( + all_tags_in_order + ), f"Number of media URLs does not match number of media tags." + + if hasattr(output, "input_str"): + output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" + if is_eval and is_last_turn: + if reasoning_content.strip() == "": + output.input_str += f"\n\n\n\n" + turn["content"] = stripped_content + else: + output.input_str += f"{turn['content']}{cls.turn_suffix}" + + if getattr(output, "input_ids", None) is not None: + role_encoded = tokenizer.encode( + f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False + ) + output.input_ids += role_encoded + + if is_eval and is_last_turn: + if reasoning_content.strip() == "": + output.input_ids += tokenizer.encode( + f"\n\n\n\n", truncation=False + ) + turn["content"] = stripped_content + else: + if turn.get("trainable_role", True): + output.label_ids += role_encoded + else: + output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] + + turn_img_idx = 0 + content_parts = pattern.split(turn["content"].strip()) + for part in content_parts: + part = part.strip() + + if not part: + continue + + if part not in [ + "<|image|>", + "<|discrete_image|>", + "<|audio|>", + "<|discrete_audio|>", + ]: + content_text = part + + if hasattr(output, "input_str"): + output.input_str += "\n" + content_text + if getattr(output, "input_ids", None) is not None: + content_encoded = tokenizer.encode( + "\n" + content_text, truncation=False + ) + output.input_ids += content_encoded + if turn.get("trainable_content", True): + output.label_ids += content_encoded + else: + output.label_ids += [ + IGNORE_INDEX for _ in range(len(content_encoded)) + ] + continue + + if part == "<|image|>": + file_name = turn.get("image_urls", [])[turn_img_idx] + if isinstance(file_name, str) and "#" in file_name: + file_name = file_name.split("#")[-1] + file_name = os.path.basename(file_name) + mime = 
Preprocessor.prompt_mime( + mimes=None, + fixed_mime=False, + file_name=file_name, + tag_idx=output.sample_mm_counter["image"], + is_video=False, + is_audio=False, + seed=None, + ) + mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" + discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" + vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" + mm_str = ( + cls.new_line + + mime_str + + cls.new_line + + discrete_image_str + + cls.new_line + + vector_str + ) + + if hasattr(output, "input_str"): + output.input_str += mm_str + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + output.label_ids += [ + IGNORE_INDEX for _ in range(len(token_ids)) + ] + turn_img_idx += 1 + output.sample_mm_counter["image"] += 1 + + elif part == "<|discrete_image|>": + discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" + mm_str = cls.new_line + discrete_image_str + if hasattr(output, "input_str"): + output.input_str += mm_str + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + output.label_ids += token_ids + turn_img_idx += 1 + + elif part == "<|discrete_audio|>": + discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" + mm_str = cls.new_line + discrete_audio_str + if hasattr(output, "input_str"): + output.input_str += mm_str + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(mm_str, truncation=False) + output.input_ids += token_ids + if turn.get("trainable_content", True): + output.label_ids += token_ids + else: + output.label_ids += [ + IGNORE_INDEX for _ in range(len(token_ids)) + ] + + elif part == "<|audio|>": + raise Exception( + "Assistant turn에서 <|audio|> 태그는 지원하지 않음. discrete_audio 만 지원함." 
+ ) + + if hasattr(output, "input_str"): + output.input_str += cls.turn_suffix + + if getattr(output, "input_ids", None) is not None: + token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) + output.input_ids += token_ids + if turn.get("trainable_content", True): + output.label_ids += token_ids + else: + output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] + + return output + + @classmethod + def prompt_tool( + cls, + output, + tokenizer=None, + turn: Optional[dict] = None, + role: Optional[str] = None, + content: Optional[str] = None, + eot: Optional[bool] = None, + need_start_tag=True, + need_end_tag=True, + ): + assert (content and role) or turn + if turn is None: + turn = { + "content": content, + "role": role, + "endofturn": eot, + } + assert ( + "tool" == turn["role"] + ), f'[warning] unexpected turn["role"]: {turn["role"]}' + content_value = turn.get("content", "") + + if isinstance(content_value, dict): + if "response" in content_value: + content_str = content_value["response"] + else: + content_str = json.dumps(content_value, ensure_ascii=False) + elif isinstance(content_value, str): + try: + parsed = json.loads(content_value) + if isinstance(parsed, dict): + if "response" in parsed: + content_str = parsed["response"] + else: + content_str = json.dumps(parsed, ensure_ascii=False) + else: + content_str = content_value + except (json.JSONDecodeError, TypeError): + content_str = content_value + else: + content_str = str(content_value) + + turn["content"] = ( + f"{turn.get('name', '')}\n{content_str}\n" + ) + + if hasattr(output, "input_str"): + if need_start_tag: + output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" + output.input_str += f"{cls.new_line}{turn['content']}" + if need_end_tag: + output.input_str += cls.turn_suffix + + if getattr(output, "input_ids", None) is not None: + if need_start_tag: + role_encoded = tokenizer.encode( + f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False + ) + output.input_ids += role_encoded + + if turn.get("trainable_role", True): + output.label_ids += role_encoded + else: + output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] + + content = f"{cls.new_line}{turn['content']}" + content_encoded = tokenizer.encode(content, truncation=False) + if need_end_tag: + content_encoded += tokenizer.encode( + f"{cls.turn_suffix}", truncation=False + ) + output.input_ids += content_encoded + if turn.get("trainable_content", True): + output.label_ids += content_encoded + else: + output.label_ids += [ + IGNORE_INDEX for _ in range(len(content_encoded)) + ] + return output + + @classmethod + def prompt_etc( + cls, + output, + tokenizer=None, + turn: Optional[dict] = None, + role: Optional[str] = None, + content: Optional[str] = None, + eot: Optional[bool] = None, + ): + assert (content and role) or turn + if turn is None: + turn = { + "content": content, + "role": role, + "endofturn": eot, + } + print(f'[warning] unexpected turn["role"]: {turn["role"]}') + + if hasattr(output, "input_str"): + output.input_str += f"{cls.turn_prefix}{turn['role']}\n" + output.input_str += f"{turn['content']}{cls.turn_suffix}" + if turn.get("stop", False): + output.input_str += cls.stop_token + if turn.get("endofturn", False): + output.input_str += cls.eot + + if getattr(output, "input_ids", None) is not None: + role_encoded = tokenizer.encode( + f"{cls.turn_prefix}{turn['role']}\n", truncation=False + ) + output.input_ids += role_encoded + + if turn.get("trainable_role", True): + output.label_ids += role_encoded 
+ else: + output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] + + content = f"{turn['content']}{cls.turn_suffix}" + if turn.get("stop", False): + content += cls.stop_token + if turn.get("endofturn", False): + content += cls.eot + content_encoded = tokenizer.encode(content, truncation=False) + output.input_ids += content_encoded + if turn.get("trainable_content", True): + output.label_ids += content_encoded + else: + output.label_ids += [IGNORE_INDEX for _ in range(len(content_encoded))] + return output + + def __call__(self, sample): + return self.preprocess_new(sample) + + @classmethod + def batchify( + cls, + items: List[Dict[str, Any],], + device: str = None, + ): + batch = dict() + for item in items: + for k, v in item.items(): + if isinstance(v, torch.Tensor): + if device is not None: + v = v.to(device=device) + elif k == "pixel_values": + v = [_v.to(device=device) for _v in v] + + if k not in batch: + batch[k] = [ + v, + ] + else: + batch[k].append(v) + + for k, v in batch.items(): + if isinstance(v[0], torch.Tensor): + if k in ["image_grid_thw", "video_grid_thw"]: + batch[k] = torch.cat(v, dim=0) + continue + batch[k] = torch.stack(v, dim=0) + batch["video_grid_thw"] = None + batch["pixel_values_videos"] = None + return batch + + def convert_wds_to_datalake( + self, + img: Union[PIL.Image.Image, Dict[str, PIL.Image.Image]] = {}, + json: Dict[str, Any] = {}, + benchmark: Optional[str] = None, + video: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, + audio: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, + ): + + if "lines" in json: + del json["lines"] + if "paragraphs" in json: + del json["paragraphs"] + + assert json["meta"]["type"] in [ + "caption", + "vqa", + "textread", + ], f"{json['meta']['path']}, {json['meta']['type']}: The dataset type should be one of them: caption, vqa, textread." 
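+        # The rest of this method reshapes a wds-style (img, json) pair into the datalake turn
+        # format: a tool_list turn and a system turn, followed by user/assistant turns built
+        # from qa_pairs_* / captions_* / words / regions_* entries. As a rough, illustrative
+        # sketch only (field values below are made up, not taken from any real shard), the
+        # metadata consumed here looks like:
+        #   json["meta"] = {"type": "vqa", "lang": "ko", "name": "example_vqa",
+        #                   "path": "shards/000.tar", "subtask": "", "reasoning": False}
+        #   json["qa_pairs_ko"] = [[question_with_image_tag, answer], ...]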
+ + sample = {"vlm": {}} + sample["vlm"] = get_wds_default_config( + json["meta"], existing_default_config=self.wds_default_config + ) + sample["vlm"]["data_name"] = json["meta"].get("name", "unk") + + sample["vlm"]["data_type"] = ( + "wds" + if (isinstance(img, PIL.Image.Image) and img) + or (isinstance(img, dict) and len(img) > 0) + else "sft1" + ) + + sample["vlm"]["sample_id"] = json.get("qa_id", None) + sample["vlm"]["category"] = json.get("category", None) + sample["vlm"]["data_info"] = json.get("data_info", dict()) + sample["vlm"]["options"] = None + if "choices_en" in sample["vlm"]["data_info"]: + if sample["vlm"]["options"] is None and json["meta"]["lang"] == "en": + sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_en"] + sample["vlm"]["options_en"] = sample["vlm"]["data_info"]["choices_en"] + if "choices_ko" in sample["vlm"]["data_info"]: + if sample["vlm"]["options"] is None and json["meta"]["lang"] == "ko": + sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_ko"] + sample["vlm"]["options_ko"] = sample["vlm"]["data_info"]["choices_ko"] + sample["vlm"]["image_index"] = json.get( + "image_index", json.get("img_url", None) + ) + + if sample["vlm"].get("video", False): + is_multi_image_dataset = False + else: + is_multi_image_dataset, img, json = convert_format_for_multi_image( + img, json + ) + + if json["meta"]["type"] == "textread": + key = "words" + elif json["meta"].get("subtask", "") == "region": + key = f"regions_{json['meta']['lang']}" + elif json["meta"]["type"] == "vqa": + key = f"qa_pairs_{json['meta']['lang']}" + elif json["meta"]["type"] == "caption": + key = f"captions_{json['meta']['lang']}" + else: + raise ConditionalError( + f"wrong task type in wds config: {sample['vlm']['data_name']}" + ) + + turns = [ + { + "role": "tool_list", + "content": "", + "content_type": "text", + "trainable_role": False, + "trainable_content": False, + "stop": False, + "debuggingInfo": {}, + "meta": {}, + "candidates": [], + "endofturn": False, + }, + { + "role": "system", + "content_type": "text", + "candidates": [], + "trainable_role": False, + "trainable_content": False, + "stop": False, + "debuggingInfo": {}, + "meta": {}, + "content": "", + "endofturn": False, + }, + ] + + if json["meta"].get("llava_pretrain", False): + sample["vlm"]["llava_pretrain"] = True + + use_task_prompt = json["meta"].get( + "use_task_prompt", self.wds_default_config["use_task_prompt"] + ) + get_random = json["meta"].get( + "get_random", self.wds_default_config["get_random"] + ) + reasoning = json["meta"].get("reasoning", self.wds_default_config["reasoning"]) + + try: + if key not in json: + key = key[:-3] + assert key in json + if len(json[key]) == 0: + key = key[:-3] + assert key in json + except: + raise ConditionalError( + f"{key} key is not in json? 
dataset name: {sample['vlm']['data_name']}" + ) + + first_turn = True + if "region" in key: + json[key] = json[key]["00"] + sample["vlm"]["multiturn_n_samples"] = 1 + if ( + not is_multi_image_dataset + and sample["vlm"]["multiturn_n_samples"] > 1 + or "region" in key + ): + json[key] = sampling_multiturn_single_img( + json[key], + sample["vlm"]["multiturn_n_samples"], + sample["vlm"]["multiturn_preserve_order"], + sample["vlm"]["multiturn_continuous"], + ) + + if sample["vlm"].get("video", False): + for qa in json[key]: + vid_src = [] + user = { + "role": "user", + "content_type": "text", + "candidates": [], + "trainable_role": False, + "trainable_content": False, + "stop": False, + "debuggingInfo": {}, + "meta": {}, + "image_urls": [], + "image_metas": [], + "video_urls": [], + "video_metas": [], + "audio_urls": [], + "audio_metas": [], + "content": "", + "endofturn": False, + } + + instruct_prompt, task_prompt = hcx_vision_prompter( + task=json["meta"]["type"], + subtask=json["meta"].get("subtask", None), + lang=json["meta"]["lang"], + get_random=get_random, + use_task_prompt=use_task_prompt, + ) + + prompt = qa[0] + answer = qa[-1] if reasoning else qa[1] + + if first_turn: + user["video_metas"].append({"lens": []}) + user["content"] += "<|video|>" + prompt = task_prompt.format(prompt) + + if "entities" in json: + user["video_metas"][0]["lens"] = json["entities"].get("00", []) + if isinstance(video, dict): + vid_src.append(video["00"]) + else: + vid_src.append(video) + first_turn = False + + user["video_urls"] = vid_src + user["content"] += prompt + + assistant = { + "candidates": [], + "content": answer, + "content_type": "text", + "debuggingInfo": {}, + "meta": {}, + "role": "assistant", + "trainable_content": True, + "trainable_role": True, + "stop": False, + "endofturn": True, + } + turns.append(user) + turns.append(assistant) + + else: + if key.startswith("qa_pairs") or key.startswith("captions"): + if self.mode != "train" and key.startswith("qa_pairs"): + qas = dict() + for qa in json[key]: + q = qa[0] + if q not in qas: + qas[q] = list() + for _i, _e in enumerate(qa[1:]): + if len(qas[q]) <= _i: + qas[q].append(list()) + qas[q][_i].append(_e) + json[key] = [ + [ + k, + ] + + v + for k, v in qas.items() + ] + + if self.mode != "train": + json[key] = json[key][:1] + + for qa in json[key]: + img_src = [] + user = { + "role": "user", + "content_type": "text", + "candidates": [], + "trainable_role": False, + "trainable_content": False, + "stop": False, + "debuggingInfo": {}, + "meta": {}, + "image_urls": [], + "image_metas": [], + "video_urls": [], + "video_metas": [], + "audio_urls": [], + "audio_metas": [], + "content": "", + "endofturn": False, + } + img_keys = re.findall(r"", qa[0]) + video_keys = re.findall(r"", qa[0]) + audio_keys = re.findall(r"", qa[0]) + + if key.startswith("qa_pairs"): + if len(qa) > 2: + sample_id = qa[2] + if ( + isinstance(sample_id, (list, tuple)) + and len(sample_id) > 0 + ): + sample_id = sample_id[0] + sample["vlm"]["sample_id"] = sample_id + + instruct_prompt, task_prompt = hcx_vision_prompter( + task=json["meta"]["type"], + subtask=json["meta"].get("subtask", None), + lang=json["meta"]["lang"], + get_random=get_random, + use_task_prompt=use_task_prompt, + ) + if json["meta"]["type"] == "vqa": + prompt = qa[0] + answer = qa[-1] if reasoning else qa[1] + elif json["meta"]["type"] == "caption": + prompt = task_prompt.format("") + answer = qa + + if first_turn or self.mode != "train": + if json["meta"]["type"] == "vqa": + prompt = 
task_prompt.format(prompt) + if first_turn and not is_multi_image_dataset: + user["image_metas"].append({"words": [], "lens": []}) + if "" in prompt: + prompt = prompt.replace("", "<|image|>") + else: + user["content"] += "<|image|>" + user["image_metas"][0]["words"] = json.get("words", {}).get( + "00", [] + ) + if "objects" in json: + user["image_metas"][0]["lens"] = json["objects"].get( + "00", [] + ) + elif "entities" in json: + user["image_metas"][0]["lens"] = json["entities"].get( + "00", [] + ) + if isinstance(img, dict): + img_src.append(img["00"]) + else: + img_src.append(img) + elif len(img_keys) > 0: + for i, key in enumerate(img_keys): + user["image_metas"].append({"words": [], "lens": []}) + if f"" in prompt: + prompt = prompt.replace(f"", "<|image|>") + else: + user["content"] += "<|image|>" + img_src.append(img[key]) + _words = json.get("words", {}) + if isinstance(_words, dict): + _words = _words.get(key, []) + user["image_metas"][i]["words"] = _words + if "objects" in json: + _objects = json["objects"].get(key, []) + if isinstance(_objects, dict): + _objects = _objects.get(key, []) + user["image_metas"][i]["lens"] = _objects + if "entities" in json: + _entities = json["entities"].get(key, []) + if isinstance(_entities, dict): + _entities = _entities.get(key, []) + user["image_metas"][i]["lens"] = _entities + user["image_urls"] = img_src + + if len(audio_keys) > 0: + for i, key in enumerate(audio_keys): + if isinstance(audio, dict): + user["audio_urls"].append(audio[key]) + else: + user["audio_urls"].append(audio) + user["audio_metas"].append( + { + "format": "wav", + "note": "This audio sample is passed to convert_wds_to_datalake function.", + } + ) + if f"" in prompt: + prompt = prompt.replace(f"", "<|audio|>") + else: + user["content"] += "<|audio|>" + + user["content"] += prompt + + content, candidates = None, list() + if self.mode != "train": + if isinstance(answer, (int, float)): + pass + elif isinstance(answer, str): + if answer != "None": + try: + answer = ast.literal_eval(answer) + except Exception as ex: + pass + if not isinstance(answer, (list, tuple)): + answer = [ + answer, + ] + candidates += answer[1:] + answer = answer[0] + content = answer + elif isinstance(answer, (list, tuple)): + for _idx, _answer in enumerate(answer): + if isinstance(_answer, str): + if isinstance(benchmark, str) and benchmark in [ + "textvqa", + ]: + try: + _answer = ast.literal_eval(_answer) + except Exception as ex: + pass + if isinstance(_answer, dict): + _answer = str(_answer) + if not isinstance(_answer, (list, tuple)): + _answer = [ + _answer, + ] + if _idx == 0: + content = _answer[0] + candidates += _answer[1:] + else: + candidates += _answer + + if isinstance(content, (int, float)): + content = str(content) + assert content is None or isinstance(content, str) + for _idx, _candidate in enumerate(candidates): + if isinstance(_candidate, (int, float)): + candidates[_idx] = str(_candidate) + assert isinstance(candidates[_idx], str) + mcqa_gt = sample["vlm"]["data_info"].get("choice_answer", None) + if isinstance(mcqa_gt, str): + content = mcqa_gt + + assistant = { + "candidates": candidates, + "content": answer if self.mode == "train" else content, + "content_type": "text", + "debuggingInfo": {}, + "meta": {}, + "role": "assistant", + "trainable_content": True, + "trainable_role": True, + "stop": False, + "endofturn": True, + } + turns.append(user) + turns.append(assistant) + + elif key == "words": + img_src = [] + user = { + "role": "user", + "content_type": "text", + 
"candidates": [], + "trainable_role": False, + "trainable_content": False, + "stop": False, + "debuggingInfo": {}, + "meta": {}, + "image_urls": [], + "image_metas": [], + "video_urls": [], + "video_metas": [], + "audio_urls": [], + "audio_metas": [], + "content": "<|image|>", + "endofturn": False, + } + instruct_prompt, task_prompt = hcx_vision_prompter( + task=json["meta"]["type"], + subtask=json["meta"].get("subtask", None), + lang=json["meta"]["lang"], + get_random=get_random, + use_task_prompt=use_task_prompt, + ) + user["content"] += task_prompt + user["image_metas"].append({"words": [], "lens": []}) + user["image_metas"][0]["words"] = json["words"]["00"] + if "entities" in json: + user["image_metas"][0]["lens"] = json["entities"].get("00", []) + img_src.append(img["00"]) + user["image_urls"] = img_src + + words_list = [ + d["text"].strip() for d in json["words"]["00"] if d["text"] + ] + gt = " ".join(words_list) + assistant = { + "candidates": [], + "content": gt, + "content_type": "text", + "debuggingInfo": {}, + "meta": {}, + "role": "assistant", + "trainable_content": True, + "trainable_role": True, + "stop": False, + "endofturn": True, + } + turns.append(user) + turns.append(assistant) + + elif key.startswith("regions"): + for region in json[key]: + img_src = [] + user = { + "role": "user", + "content_type": "text", + "candidates": [], + "trainable_role": False, + "trainable_content": False, + "stop": False, + "debuggingInfo": {}, + "meta": {}, + "image_urls": [], + "image_metas": [], + "video_urls": [], + "video_metas": [], + "audio_urls": [], + "audio_metas": [], + "content": "<|image|><|region|>", + "endofturn": False, + } + instruct_prompt, task_prompt = hcx_vision_prompter( + task=json["meta"]["type"], + subtask=json["meta"].get("subtask", None), + lang=json["meta"]["lang"], + get_random=get_random, + use_task_prompt=use_task_prompt, + ) + sample["vlm"]["query_template"] = [task_prompt] + user["image_metas"].append({"words": [], "lens": []}) + user["image_metas"][0]["region"] = region + if "words" in json: + user["image_metas"][0]["words"] = json["words"].get("00", []) + if "objects" in json: + user["image_metas"][0]["lens"] = json["objects"].get("00", []) + if "entities" in json: + user["image_metas"][0]["lens"] = json["entities"].get("00", []) + img_src.append(img["00"]) + user["image_urls"] = img_src + + assistant = { + "candidates": [], + "content": region["text"], + "content_type": "text", + "debuggingInfo": {}, + "meta": {}, + "role": "assistant", + "trainable_content": True, + "trainable_role": True, + "stop": False, + "endofturn": True, + } + turns.append(user) + turns.append(assistant) + else: + raise ConditionalError( + f"wrong task type in wds config: {sample['vlm']['data_name']}" + ) + sample["data"] = turns + return sample + + def preprocess_new(self, sample): + + config = sample.get("vlm", {}) + if config["data_type"] in ["sft1", "datalake"]: + default_config = copy.deepcopy(self.default_config) + default_config.update(config) + config = default_config + idx_for_debug = sample.get("idx", -1) + turns = sample["data"] if "data" in sample else sample["messages"] + + if self.random_system_prompt and self.rng.random() < config.get( + "random_system_prob", 0.0 + ): + for turn in turns: + if turn["role"] == "system": + turn["content"] = self.random_system_prompt + break + + if sample.get("tools", None) is None: + sample["tools"] = [] + + if len(sample["tools"]) == 0: + if ( + self.rng.random() < config.get("random_tool_prob", 0.005) + and len(self.common_tools) > 0 + 
): + + max_n_tools = min(7, len(self.common_tools)) + tool_counts = np.arange(1, max_n_tools + 1) + tool_count_weights = 1.0 / tool_counts + tool_count_weights = tool_count_weights / tool_count_weights.sum() + n_tools = int(self.rng.choice(tool_counts, p=tool_count_weights)) + + idxs = np.arange(len(self.common_tools)) + weights = 1.0 / (idxs + 1) + weights[0] += 1.0 + weights = weights / weights.sum() + + chosen_indices = self.rng.choice( + len(self.common_tools), size=n_tools, replace=False, p=weights + ) + + self.rng.shuffle(chosen_indices) + + sample["tools"] = [self.common_tools[i] for i in chosen_indices] + + if "tools" in sample and sample["tools"]: + tool_prompt = [] + tool_prompt.append("# Tools\n\n") + tool_prompt.append( + "You may call one or more functions to assist with the user query.\n\n" + ) + tool_prompt.append( + "You are provided with function signatures within XML tags:\n" + ) + tool_prompt.append("\n") + for tool in sample["tools"]: + tool_prompt.append(json.dumps(tool, ensure_ascii=False)) + tool_prompt.append("\n\n\n") + tool_prompt.append( + "For each function call, output the function name and arguments within the following XML format:\n" + ) + tool_prompt.append("{function-name}\n") + tool_prompt.append("{arg-key-1}\n") + tool_prompt.append("{arg-value-1}\n") + tool_prompt.append("{arg-key-2}\n") + tool_prompt.append("{arg-value-2}\n") + tool_prompt.append("...\n") + tool_prompt.append("") + + tool_prompt = "".join(tool_prompt) + else: + tool_prompt = "" + + multiturn_n_sample = config.get("multiturn_n_samples", 0) + if multiturn_n_sample > 0 and self.mode == "train": + turns = self._sampling_multiturn( + turns, + multiturn_n_sample, + multiturn_preserve_order=config.get("multiturn_preserve_order", True), + multiturn_continuous=config.get("multiturn_continuous", False), + ) + + for i, turn in enumerate(turns): + if turn["role"] == "user": + if "img_src" in turn: + turns[i]["image_urls"] = turn["img_src"] + turns[i]["image_metas"] = turn["meta"] + for j, turn_img_meta in enumerate(turns[i]["image_metas"]): + if "entities" in turn_img_meta: + turns[i]["image_metas"][j]["lens"] = turn_img_meta[ + "entities" + ] + turns[i]["meta"] = {} + + max_image_cnt = config.get("max_image_cnt", 20) + if max_image_cnt > 0 and config["data_type"] != "sft1": + n_imgs = {} + for i, turn in enumerate(turns): + if turn["role"] == "user": + n_imgs[i] = len(turn.get("image_urls", [])) + assert ( + n_imgs[i] <= max_image_cnt + ), "skip sample if image_nums exceeds max_image_count per turn" + + if sum(n_imgs.values()) > max_image_cnt: + img_count = 0 + for k, v in reversed(list(n_imgs.items())): + img_count += v + if img_count > max_image_cnt: + break + + img_count = sum(n_imgs.values()) - max_image_cnt + + for i in range(k + 1): + if turns[i]["role"] == "user": + turns[i]["content"], n_removed1 = re.subn( + r"", + "", + turns[i]["content"].strip(), + count=img_count, + ) + img_count -= n_removed1 + turns[i]["content"], n_removed2 = re.subn( + r"<\|image\|>", + "", + turns[i]["content"].strip(), + count=img_count, + ) + img_count -= n_removed2 + n_removed_imgs = n_removed1 + n_removed2 + turns[i]["image_urls"] = turns[i]["image_urls"][n_removed_imgs:] + + if n_removed_imgs > 0 and len(turns[i]["image_urls"]) == 0: + idx = i + while True: + idx += 1 + turns[idx]["trainable_role"] = False + turns[idx]["trainable_content"] = False + if turns[idx]["role"] == "assistant": + break + + n_imgs_after = {} + for i, turn in enumerate(turns): + if turn["role"] == "user": + n_imgs_after[i] = 
len(turn.get("image_urls", [])) + assert sum(n_imgs_after.values()) > 0, "The n_imgs of vlm data is zero." + + n_mm_after = {} + for i, turn in enumerate(turns): + if turn["role"] == "user" or turn["role"] == "assistant": + n_mm_after[i] = ( + len(turn.get("image_urls", [])) + + len(turn.get("video_urls", [])) + + len(turn.get("audio_urls", [])) + ) + assert sum(n_mm_after.values()) > 0, "The n_mm of omni data is zero." + + queries, gts = list(), list() + output = Processed_sample( + input_str="", + input_ids=[], + label_ids=[], + imgs=[], + discrete_imgs=[], + videos=[], + videos_duration=[], + video_audios=[], + audios=[], + audios_duration=[], + discrete_audios=[], + sample_mm_counter={ + "image": 0, + "video": 0, + "audio": 0, + }, + ) + system_role_count = 0 + last_user_idx = max( + (i for i, d in enumerate(turns) if d.get("role") == "user"), default=-1 + ) + for i, turn in enumerate(turns): + if turn["role"] == "tool_list": + continue + + elif turn["role"] == "system": + if config.get("llava_pretrain", False): + continue + output = Preprocessor.prompt_system( + turn=turn, + output=output, + tokenizer=self.tokenizer, + seed=self.rng, + tool_prompt=tool_prompt, + system_role_count=system_role_count, + ) + system_role_count += 1 + + elif turn["role"].startswith("user"): + output = Preprocessor.load_mm( + output=output, + img_dir=config.get("img_dir", ""), + turn=turn, + prepare_input_fn=self.prepare_input_fn, + max_image_cnt=max_image_cnt, + video_max_num_frames=self.video_max_num_frames, + video_max_pixels=self.video_max_pixels, + use_audio=self.train_audio, + ) + output = Preprocessor.prompt_user( + output=output, + tokenizer=self.tokenizer, + turn=turn, + is_train=True if self.mode == "train" else False, + fixed_mime=config.get("fixed_mime", False), + mimes=self.mimes, + query_template=config.get("query_template", None), + config=config, + seed=self.rng, + ) + + queries.append(turn["content"].replace("<|image|>", "").strip()) + elif turn["role"].startswith("assistant"): + output = Preprocessor.load_mm( + output=output, + img_dir=config.get("img_dir", ""), + turn=turn, + prepare_input_fn=self.prepare_input_fn, + max_image_cnt=max_image_cnt, + video_max_num_frames=self.video_max_num_frames, + video_max_pixels=self.video_max_pixels, + use_audio=self.train_audio, + ) + + is_after_last_user = i > last_user_idx + is_first_assistant_after_last_user = False + if is_after_last_user: + is_first_assistant_after_last_user = all( + turns[j]["role"] != "assistant" + for j in range(last_user_idx + 1, i) + ) + + output = Preprocessor.prompt_assistant( + output=output, + tokenizer=self.tokenizer, + turn=turn, + is_last_turn=is_first_assistant_after_last_user, + is_eval=True if self.mode != "train" else False, + is_llava_pretrain=config.get("llava_pretrain", False), + is_after_last_user_turn=is_after_last_user, + ) + _gts = turn["content"] + if isinstance(_gts, str): + _gts = [ + _gts, + ] + if "candidates" in turn and len(turn["candidates"]) > 0: + for _candidates in turn["candidates"]: + if isinstance(_candidates, str): + _gts += [ + _candidates, + ] + elif isinstance(turn["candidates"][0], (list, tuple)): + _gts += _candidates + gts.append(_gts) + elif turn["role"] == "tool": + if config.get("llava_pretrain", False): + continue + + output = Preprocessor.prompt_tool( + output=output, + tokenizer=self.tokenizer, + turn=turn, + need_start_tag=( + True + if (i == 0 or turns[i - 1].get("role") != "tool") + else False + ), + need_end_tag=( + True + if (i == (len(turns) - 1) or turns[i + 1].get("role") != 
"tool") + else False + ), + ) + else: + if config.get("llava_pretrain", False): + continue + + import pdb + import sys + + class ForkedPdb(pdb.Pdb): + """A Pdb subclass that may be used from a forked multiprocessing child""" + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + ForkedPdb().set_trace() + output = Preprocessor.prompt_etc( + output=output, + tokenizer=self.tokenizer, + turn=turn, + ) + + pixel_values = [] + mm_query_lengths = [] + discrete_pixel_values = [] + image_ratios = [] + discrete_image_query_lengths = [] + + labels = output.label_ids + input_ids = output.input_ids + total_mm_query_length = 0 + + is_sft1 = False + if config["data_type"] == "sft1": + if self.sequence_parallel_size > 1: + if len(input_ids) % self.sequence_parallel_size != 0: + input_ids += [self.tokenizer.pad_token_id] * ( + self.sequence_parallel_size + - (len(input_ids) % self.sequence_parallel_size) + ) + labels += [IGNORE_INDEX] * ( + self.sequence_parallel_size + - (len(labels) % self.sequence_parallel_size) + ) + + input_ids = input_ids[ + : (len(input_ids) // self.sequence_parallel_size) + * self.sequence_parallel_size + ] + labels = labels[ + : (len(labels) // self.sequence_parallel_size) + * self.sequence_parallel_size + ] + + input_ids = torch.tensor(input_ids[-self.decoder_max_length :]) + labels = torch.tensor(labels[-self.decoder_max_length :]) + is_sft1 = True + + dummy_preprocess_results = self.prepare_input_fn.image_processor( + Image.new("RGB", (224, 224), (0, 0, 0)) + ) + dummy_pixel_values = torch.from_numpy( + np.concatenate([dummy_preprocess_results.pixel_values], axis=0) + ) + dummy_grid_thw = torch.from_numpy( + np.concatenate([dummy_preprocess_results.image_grid_thw], axis=0) + ) + + image_grid_thw = [] + for img in output.imgs: + w, h = img.size + + img = self._resize_min_edge(img) + preprocess_results = self.prepare_input_fn.image_processor([img]) + pixel_values.append(preprocess_results.pixel_values) + image_grid_thw.append(preprocess_results.image_grid_thw) + mm_query_lengths.append(preprocess_results.pixel_values.shape[0] // 4) + + if len(output.imgs) == 0: + pixel_values = torch.zeros(0, 1176) + image_grid_thw = torch.zeros(0, 3, dtype=torch.long) + else: + pixel_values = torch.from_numpy(np.concatenate(pixel_values, axis=0)) + image_grid_thw = torch.from_numpy(np.concatenate(image_grid_thw, axis=0)) + + for img in output.discrete_imgs: + w, h = img.size + + img_ratio = self._find_best_ratio_token([h, w]) + image_ratios.append(img_ratio) + discrete_pixel_value = img.resize((384, 384), Image.BICUBIC) + discrete_pixel_tensor = to_tensor(discrete_pixel_value) + + assert discrete_pixel_tensor.shape == ( + 3, + 384, + 384, + ), f"Unexpected discrete_pixel_tensor shape: {discrete_pixel_tensor.shape}" + assert not torch.isnan( + discrete_pixel_tensor + ).any(), "discrete_pixel_tensor contains NaN" + assert not torch.isinf( + discrete_pixel_tensor + ).any(), "discrete_pixel_tensor contains Inf" + pixel_min = discrete_pixel_tensor.min().item() + pixel_max = discrete_pixel_tensor.max().item() + assert ( + 0.0 <= pixel_min <= 1.0 and 0.0 <= pixel_max <= 1.0 + ), f"discrete_pixel_tensor values out of range [0, 1]: min={pixel_min}, max={pixel_max}" + + discrete_pixel_values.append(discrete_pixel_tensor) + discrete_image_query_lengths.append(729) + + if len(output.discrete_imgs) == 0: + discrete_pixel_values = torch.zeros(0, 3, 384, 384) + else: + 
discrete_pixel_values = torch.stack(discrete_pixel_values, dim=0) + + assert discrete_pixel_values.shape[1:] == ( + 3, + 384, + 384, + ), f"Unexpected stacked discrete_pixel_values shape: {discrete_pixel_values.shape}" + assert not torch.isnan( + discrete_pixel_values + ).any(), "Stacked discrete_pixel_values contains NaN" + assert not torch.isinf( + discrete_pixel_values + ).any(), "Stacked discrete_pixel_values contains Inf" + + pixel_values_videos = None + video_grid_thw = None + if self.train_video: + pixel_values_videos = [] + video_grid_thw = [] + video_query_lengths = [] + for video in output.videos: + preprocess_results = self.prepare_input_fn.video_processor([video]) + pixel_values_videos.append(preprocess_results.pixel_values_videos) + video_grid_thw.append(preprocess_results.video_grid_thw) + video_query_lengths.append( + preprocess_results.pixel_values_videos.shape[0] // 4 + ) + if len(output.videos) == 0: + pixel_values_videos = torch.zeros(0, 1176) + video_grid_thw = torch.zeros(0, 3, dtype=torch.long) + else: + pixel_values_videos = torch.from_numpy( + np.concatenate(pixel_values_videos, axis=0) + ) + video_grid_thw = torch.from_numpy( + np.concatenate(video_grid_thw, axis=0) + ) + + video_audio_values = [] + video_audio_masks = [] + video_audio_query_lengths = [] + if self.train_video and hasattr(output, "video_audios") and output.video_audios: + for idx, video_audio_chunks in enumerate(output.video_audios): + if video_audio_chunks: + processed_audio_values = [] + processed_audio_masks = [] + chunk_output_lengths = [] + + for chunk in video_audio_chunks: + if isinstance(chunk, torch.Tensor): + chunk_np = chunk.cpu().numpy() + else: + chunk_np = chunk + + preprocess_results = self.prepare_audio_input_fn( + [chunk_np], + sampling_rate=self.prepare_audio_input_fn.sampling_rate, + return_attention_mask=True, + padding="max_length", + ) + + audio_value = preprocess_results.input_features[0] + audio_mask = preprocess_results.attention_mask[0] + + mask_sum = int(audio_mask.sum()) + input_lengths = (mask_sum - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + chunk_output_lengths.append(output_lengths) + + processed_audio_values.append(torch.from_numpy(audio_value)) + processed_audio_masks.append(torch.from_numpy(audio_mask)) + + pool_size = 25 + if self.video_audio_compressor_type is not None: + total_valid_len = sum(chunk_output_lengths) + total_audio_query_length = ( + total_valid_len + pool_size - 1 + ) // pool_size + else: + total_audio_query_length = sum( + (valid_len + pool_size - 1) // pool_size + for valid_len in chunk_output_lengths + ) + + video_audio_values.append(processed_audio_values) + video_audio_masks.append(processed_audio_masks) + video_audio_query_lengths.append(total_audio_query_length) + + import os + + if ( + int(os.environ.get("RANK", -1)) == 0 + and total_audio_query_length == 177 + ): + print( + f"\n[PREPROCESSOR VIDEO - 177 TOKENS DETECTED!] 
total_audio_query_length={total_audio_query_length}, num_chunks={len(processed_audio_masks)}" + ) + for chunk_idx, mask_tensor in enumerate(processed_audio_masks): + chunk_mask_sum = int(mask_tensor.sum()) + chunk_input_len = (chunk_mask_sum - 1) // 2 + 1 + chunk_output_len = (chunk_input_len - 2) // 2 + 1 + chunk_pooled = (chunk_output_len + 24) // 25 + print( + f" Chunk {chunk_idx}: mask_sum={chunk_mask_sum}, output_len={chunk_output_len}, pooled={chunk_pooled}" + ) + print() + + else: + video_audio_values.append([]) + video_audio_masks.append([]) + video_audio_query_lengths.append(0) + + dummy_video_preprocess_results = self.prepare_input_fn.video_processor( + [Image.new("RGB", (224, 224), (0, 0, 0))] * 3 + ) + dummy_pixel_values_videos = torch.from_numpy( + np.concatenate([dummy_video_preprocess_results.pixel_values_videos], axis=0) + ) + dummy_video_grid_thw = torch.from_numpy( + np.concatenate([dummy_video_preprocess_results.video_grid_thw], axis=0) + ) + dummy_video_preprocess_results = self.prepare_audio_input_fn( + [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], + sampling_rate=self.prepare_audio_input_fn.sampling_rate, + return_attention_mask=True, + padding="max_length", + ) + dummy_video_audio_values = torch.from_numpy( + dummy_video_preprocess_results.input_features + ) + dummy_video_audio_masks = torch.from_numpy( + dummy_video_preprocess_results.attention_mask + ) + + audio_values = None + discrete_audio_values = None + audio_masks = None + dummy_preprocess_results = self.prepare_audio_input_fn( + [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], + sampling_rate=self.prepare_audio_input_fn.sampling_rate, + return_attention_mask=True, + padding="max_length", + ) + dummy_audio_values = torch.from_numpy(dummy_preprocess_results.input_features) + dummy_audio_masks = torch.from_numpy(dummy_preprocess_results.attention_mask) + if self.train_audio: + audio_values = [] + discrete_audio_values = [] + audio_masks = [] + audio_query_lengths = [] + discrete_audio_query_lengths = [] + + if len(output.audios) > 99: + raise ConditionalError( + f"Too many audio segments in one sample: {len(output.audios)} audios." + ) + + for audio in output.audios: + chunks = [] + for i in range( + 0, len(audio), 30 * self.prepare_audio_input_fn.sampling_rate + ): + chunks.append( + audio[i : i + 30 * self.prepare_audio_input_fn.sampling_rate] + ) + num_of_chunks = len(chunks) + preprocess_results = self.prepare_audio_input_fn( + chunks, + sampling_rate=self.prepare_audio_input_fn.sampling_rate, + return_attention_mask=True, + padding="max_length", + ) + audio_value = preprocess_results.input_features + audio_mask = preprocess_results.attention_mask + audio_values.append(audio_value) + audio_masks.append(audio_mask) + input_lengths = int(audio_mask.sum()) + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + audio_query_lengths.append(output_lengths) + + if len(output.audios) == 0: + audio_values = torch.zeros(0, 128, 3000) + audio_masks = torch.zeros(0, 3000) + else: + audio_values = torch.from_numpy(np.concatenate(audio_values, axis=0)) + audio_masks = torch.from_numpy(np.concatenate(audio_masks, axis=0)) + + for audio in output.discrete_audios: + audio_length = len(audio) + + assert audio_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( + f"discrete_audio is too short ({audio_length} samples < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}). " + f"This will cause 0-dim/empty tensor in CosyVoice encoder. 
" + f"Skip this sample." + ) + + max_audio_length = 600 * DEFAULT_SAMPLE_RATE + audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE + assert ( + audio_length <= max_audio_length + ), f"discrete_audio is too long ({audio_length} samples = {audio_duration_sec:.1f}s > 600s). " + + assert not torch.isnan(audio).any(), ( + f"discrete_audio contains NaN values! " + f"This will cause CUDA illegal memory access. Skip this sample." + ) + assert not torch.isinf(audio).any(), ( + f"discrete_audio contains Inf values! " + f"This will cause CUDA illegal memory access. Skip this sample." + ) + + audio_min, audio_max = audio.min().item(), audio.max().item() + assert -100.0 <= audio_min <= 100.0 and -100.0 <= audio_max <= 100.0, ( + f"discrete_audio has extreme values (min={audio_min:.2f}, max={audio_max:.2f}). " + f"Expected roughly [-1, 1] range. This indicates corrupted audio. Skip this sample." + ) + + discrete_audio_values.append(audio) + + if audio_length > 80 * DEFAULT_SAMPLE_RATE: + chunk_size = 80 * DEFAULT_SAMPLE_RATE + + total_code_len = 0 + + for start in range(0, audio_length, chunk_size): + end = min(start + chunk_size, audio_length) + + if ( + end < audio_length + and audio_length - end < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES + ): + end = audio_length + + chunk_length = end - start + + assert chunk_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( + f"chunk_length={chunk_length} < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}. This should never happen with our chunking logic. " + f"audio_length={audio_length}, start={start}, end={end}. Skip this sample." + ) + + mel_len = chunk_length // 160 + + assert mel_len > 0, ( + f"mel_len={mel_len} is invalid (chunk_length={chunk_length}). " + f"This will cause illegal memory access in AudioEncoder. Skip this sample." + ) + + after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 + code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 + + assert code_len > 0, ( + f"code_len={code_len} is invalid (mel_len={mel_len}, after_conv1={after_conv1}). " + f"This will cause illegal memory access. Skip this sample." + ) + + total_code_len += code_len + + if end >= audio_length: + break + + assert total_code_len > 0, ( + f"total_code_len={total_code_len} is invalid after processing all chunks. " + f"audio_length={audio_length}. This should never happen. Skip this sample." + ) + + audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE + max_expected_codes = int(audio_duration_sec * 25 * 1.1) + assert total_code_len <= max_expected_codes, ( + f"total_code_len={total_code_len} is suspiciously large (max_expected={max_expected_codes}). " + f"audio_length={audio_length} ({audio_duration_sec:.1f}s). " + f"Expected ~{int(audio_duration_sec * 25)} tokens (25 tokens/sec). " + f"This indicates calculation error. Skip this sample." + ) + + discrete_audio_query_lengths.append(total_code_len) + else: + mel_len = audio_length // 160 + + assert mel_len > 0, ( + f"mel_len={mel_len} is invalid (audio_length={audio_length}). " + f"This will cause illegal memory access in AudioEncoder. Skip this sample." + ) + + after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 + code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 + + assert code_len > 0, ( + f"Calculated code_len={code_len} is invalid (audio_length={audio_length}, " + f"mel_len={mel_len}, after_conv1={after_conv1}). " + f"This indicates corrupted audio data. Skip this sample." + ) + + assert code_len <= 2048, ( + f"code_len={code_len} exceeds freqs_cis max length (2048). 
" + f"Audio length: {audio_length / DEFAULT_SAMPLE_RATE:.1f}s (max ~82s for single chunk). " + f"Expected ~{int((audio_length / DEFAULT_SAMPLE_RATE) * 25)} tokens at 25 tokens/sec. " + f"This will cause illegal memory access in apply_rotary_emb. Skip this sample." + ) + + discrete_audio_query_lengths.append(code_len) + + img_start_ids = [ + i for i, token in enumerate(input_ids) if token == self.img_token + ] + assert len(img_start_ids) == len(mm_query_lengths) + for i, length in zip( + range(len(mm_query_lengths) - 1, -1, -1), mm_query_lengths[::-1] + ): + labels[img_start_ids[i] : img_start_ids[i] + 1] = [IGNORE_INDEX] * length + input_ids[img_start_ids[i] : img_start_ids[i] + 1] = [ + self.img_token + ] * length + total_mm_query_length += length + + discrete_image_start_ids = [ + i for i, token in enumerate(input_ids) if token == self.discrete_image_token + ] + assert len(discrete_image_start_ids) == len(discrete_image_query_lengths) + assert len(discrete_image_start_ids) == len( + image_ratios + ), "discrete_image_start_ids and image_ratios length mismatch" + + for idx in range(len(discrete_image_query_lengths) - 1, -1, -1): + i = discrete_image_start_ids[idx] + length = discrete_image_query_lengths[idx] + ratio_token_id = image_ratios[idx] + assert ( + length == 729 + ), f"discrete_image_query_length must be 729, but got {length}" + + token_sequence = [ratio_token_id] + for token_idx in range(length): + token_sequence.append(self.discrete_image_token) + if (token_idx + 1) % 27 == 0: + token_sequence.append(self.discrete_image_eol_token) + token_sequence.append(self.discrete_image_eof_token) + + total_length = len(token_sequence) + if labels[i] == IGNORE_INDEX: + labels[i : i + 1] = [IGNORE_INDEX] * total_length + else: + labels[i : i + 1] = token_sequence + input_ids[i : i + 1] = token_sequence + + if self.train_video: + vid_start_ids = [ + i for i, token in enumerate(input_ids) if token == self.video_token + ] + + for idx in range(len(vid_start_ids) - 1, -1, -1): + pos = vid_start_ids[idx] + + num_frames = int(video_grid_thw[idx][0]) + frame_query_length = video_query_lengths[idx] + + has_video_audio = ( + idx < len(video_audio_query_lengths) + and video_audio_query_lengths[idx] > 0 + ) + + if has_video_audio: + total_audio_tokens = video_audio_query_lengths[idx] + + token_sequence = [] + + if num_frames > 0: + + frame_base = frame_query_length // num_frames + frame_remainder = frame_query_length % num_frames + + assert frame_remainder == 0, ( + f"frame_query_length({frame_query_length}) must be divisible by num_frames({num_frames}). " + f"Each frame produces fixed number of tokens. Got remainder={frame_remainder}." 
+ ) + + audio_base = total_audio_tokens // num_frames + audio_remainder = total_audio_tokens % num_frames + + for frame_idx in range(num_frames): + frame_tokens = frame_base + ( + 1 if frame_idx < frame_remainder else 0 + ) + token_sequence.extend([self.video_token] * frame_tokens) + + audio_tokens = audio_base + ( + 1 if frame_idx < audio_remainder else 0 + ) + if audio_tokens > 0: + token_sequence.extend( + [self.video_audio_token] * audio_tokens + ) + else: + token_sequence = [self.video_token] * frame_query_length + else: + token_sequence = [self.video_token] * frame_query_length + + total_length = len(token_sequence) + labels[pos : pos + 1] = [IGNORE_INDEX] * total_length + input_ids[pos : pos + 1] = token_sequence + + if self.train_audio: + audio_start_ids = [ + i for i, token in enumerate(input_ids) if token == self.audio_token + ] + assert len(audio_start_ids) == len(audio_query_lengths) + for i, length in zip( + range(len(audio_query_lengths) - 1, -1, -1), audio_query_lengths[::-1] + ): + labels[audio_start_ids[i] : audio_start_ids[i] + 1] = [ + IGNORE_INDEX + ] * length + input_ids[audio_start_ids[i] : audio_start_ids[i] + 1] = [ + self.audio_token + ] * length + + discrete_audio_start_ids = [ + i + for i, token in enumerate(input_ids) + if token == self.discrete_audio_token + ] + + assert len(discrete_audio_start_ids) == len(discrete_audio_query_lengths), ( + f"discrete_audio_start_ids count ({len(discrete_audio_start_ids)}) != " + f"discrete_audio_query_lengths count ({len(discrete_audio_query_lengths)}). " + f"This indicates a serious bug in preprocessor or corrupted data. Skip this sample." + ) + + for i, length in zip( + range(len(discrete_audio_query_lengths) - 1, -1, -1), + discrete_audio_query_lengths[::-1], + ): + assert 0 < length < 16000, ( + f"discrete_audio_query_length={length} is out of valid range [1, 16000). " + f"Expected max ~15,000 for 600s audio at 25 tokens/sec. " + f"This can cause illegal memory access when creating embeddings. Skip this sample." 
+ ) + + if labels[discrete_audio_start_ids[i]] == IGNORE_INDEX: + labels[ + discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 + ] = [IGNORE_INDEX] * length + else: + labels[ + discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 + ] = [self.discrete_audio_token] * length + input_ids[ + discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 + ] = [self.discrete_audio_token] * length + + if self.sequence_parallel_size > 1: + if len(input_ids) % self.sequence_parallel_size != 0: + input_ids += [self.tokenizer.pad_token_id] * ( + self.sequence_parallel_size + - (len(input_ids) % self.sequence_parallel_size) + ) + labels += [IGNORE_INDEX] * ( + self.sequence_parallel_size + - (len(labels) % self.sequence_parallel_size) + ) + + if not is_sft1: + input_ids = torch.tensor(input_ids) + labels = torch.tensor(labels) + + if self.mode == "train": + if self.sample_min_length is not None and self.sample_min_length > 0: + assert ( + len(labels) >= self.sample_min_length + ), "The sample is too short: {} < {}".format( + len(labels), self.sample_min_length + ) + assert ( + len(labels) <= self.decoder_max_length + ), "The sample exceeds decoder_max_len: {} > {}".format( + len(labels), self.decoder_max_length + ) + assert len(input_ids) == len(labels) + + if len(labels) < 30: + raise ConditionalError( + "The sample is too short: {}".format(len(labels)) + ) + + if torch.all(labels == IGNORE_INDEX): + raise ConditionalError( + "Labels contain only IGNORE_INDEX, no training targets available" + ) + + sample = { + "pixel_values": pixel_values, + "discrete_pixel_values": discrete_pixel_values, + "idx_for_debug": idx_for_debug, + "input_ids": input_ids, + "labels": labels, + "queries": queries if len(queries) > 0 else None, + "gts": gts if len(gts) > 0 else None, + "mm_query_lengths": mm_query_lengths, + "non_mm_query_lengths": len(labels) - total_mm_query_length, + "total_length": len(labels), + "data_name": config["data_name"], + "data_type": config["data_type"], + "img_start_ids": img_start_ids, + "prompt": output.input_str, + "options": config.get("options", None), + "image_grid_thw": image_grid_thw, + "pixel_values_videos": pixel_values_videos, + "video_grid_thw": video_grid_thw, + "video_audio_values": ( + video_audio_values if len(video_audio_values) > 0 else None + ), + "video_audio_masks": ( + video_audio_masks if len(video_audio_masks) > 0 else None + ), + "audio_values": audio_values, + "discrete_audio_values": discrete_audio_values, + "audio_masks": audio_masks, + "dummy_pixel_values": dummy_pixel_values, + "dummy_grid_thw": dummy_grid_thw, + "dummy_audio_values": dummy_audio_values, + "dummy_audio_masks": dummy_audio_masks, + "dummy_pixel_values_videos": dummy_pixel_values_videos, + "dummy_video_grid_thw": dummy_video_grid_thw, + "dummy_video_audio_values": dummy_video_audio_values, + "dummy_video_audio_masks": dummy_video_audio_masks, + } + + return sample + + def _sampling_multiturn( + self, + turns, + n_sample, + multiturn_preserve_order=True, + multiturn_continuous=False, + ): + new_turns = [] + sample_indices = [] + first_user_turn = True + start_idx = 0 + for idx, turn in enumerate(turns): + if turn["role"] in ["system", "tool_list"]: + new_turns.append(turn) + start_idx = idx + 1 + continue + if turn["role"] == "user": + image_nums = re.findall(r"", turn["content"]) + if len(image_nums) == 0: + image_nums = re.findall(r"<\|image\|>", turn["content"]) + if len(image_nums) > 0: + if first_user_turn: + first_user_turn = False + continue + sample_indices.append([i for i 
in range(start_idx, idx)])
+                    start_idx = idx
+        sample_indices.append([i for i in range(start_idx, idx + 1)])
+        n_sample = min(n_sample, len(sample_indices))
+        if multiturn_continuous:
+            start_index = random.randint(0, len(sample_indices) - n_sample)
+            indices = range(start_index, start_index + n_sample)
+        elif multiturn_preserve_order:
+            indices = sorted(random.sample(range(len(sample_indices)), n_sample))
+        else:
+            indices = random.sample(range(len(sample_indices)), n_sample)
+        sampled_indices = [sample_indices[i] for i in indices]
+        new_turns = new_turns + [
+            turns[i] for sampled_turns in sampled_indices for i in sampled_turns
+        ]
+        return new_turns
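+
+    # ------------------------------------------------------------------ #
+    # Illustrative sketches (not called by the pipeline): the audio token
+    # budget arithmetic that the asserts in preprocess_new rely on, kept in
+    # one place for reference. Helper names and the exact chunk-merging
+    # behaviour below are assumptions that mirror the inline code: 16 kHz
+    # audio, a 160-sample mel hop, two stride-2 convs (kernel 3, padding 1),
+    # 80 s discrete-audio chunks, and a 0.1 s minimum tail chunk.
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _sketch_discrete_audio_code_len(num_samples: int) -> int:
+        """Approximate discrete-audio (CosyVoice-style) code length for a clip.
+
+        Example: 4 s at 16 kHz -> mel_len 400 -> code_len 100, i.e. ~25 codes/s.
+        """
+        chunk_size = 80 * DEFAULT_SAMPLE_RATE
+        total_code_len = 0
+        start = 0
+        while start < num_samples:
+            end = min(start + chunk_size, num_samples)
+            # a tail shorter than 1600 samples (0.1 s) is merged into this chunk
+            if (
+                end < num_samples
+                and num_samples - end < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES
+            ):
+                end = num_samples
+            chunk_length = end - start
+            mel_len = chunk_length // 160  # 100 mel frames per second
+            after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1
+            code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1
+            total_code_len += code_len
+            start = end
+        return total_code_len
+
+    @staticmethod
+    def _sketch_audio_query_len(mask_sum: int, pool_size: Optional[int] = None) -> int:
+        """Per-chunk query length from a feature-extractor attention-mask sum.
+
+        pool_size=None mirrors the plain audio path (e.g. a full 30 s chunk
+        with mask_sum 3000 -> 750 tokens); pool_size=25 mirrors the per-chunk
+        video-audio pooling (3000 -> 30 tokens).
+        """
+        input_lengths = (mask_sum - 1) // 2 + 1  # first stride-2 subsampling
+        output_lengths = (input_lengths - 2) // 2 + 1  # second stride-2 conv
+        if pool_size is None:
+            return output_lengths
+        return (output_lengths + pool_size - 1) // pool_size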