|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import json |
|
|
import logging |
|
|
import numpy as np |
|
|
import os |
|
|
import os.path |
|
|
import os.path as osp |
|
|
import shutil |
|
|
import warnings |
|
|
from abc import ABC |
|
|
from collections import OrderedDict, defaultdict, deque |
|
|
from copy import deepcopy |
|
|
from itertools import chain |
|
|
from threading import Thread |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchvision |
|
|
from einops import rearrange |
|
|
from PIL import Image |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModel, |
|
|
AutoProcessor, |
|
|
AutoTokenizer, |
|
|
GenerationConfig, |
|
|
LogitsProcessor, |
|
|
PretrainedConfig, |
|
|
PreTrainedModel, |
|
|
Qwen2Config, |
|
|
Qwen2ForCausalLM, |
|
|
Qwen2PreTrainedModel, |
|
|
TextIteratorStreamer, |
|
|
WhisperFeatureExtractor, |
|
|
) |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
from transformers.modeling_utils import ContextManagers, no_init_weights |
|
|
|
|
|
from .auto_processor import VILAProcessor |
|
|
from .base_projector import MultimodalProjector, MultimodalProjectorConfig |
|
|
from .sound_base_projector import SoundMultimodalProjector, SoundMultimodalProjectorConfig |
|
|
from .speech_base_projector import SpeechMultimodalProjector, SpeechMultimodalProjectorConfig |
|
|
|
|
|
from .builder import build_llm_and_tokenizer |
|
|
from .configuration_vila import VILAConfig |
|
|
from .constants import * |
|
|
from .conversation import SeparatorStyle, default_conversation |
|
|
from .distributed import all_gather as vila_all_gather |
|
|
from .media import extract_media |
|
|
from .media_encoder import BasicImageEncoder, BasicVideoEncoder, TSPVideoEncoder, BasicSoundEncoder, CacheFeatures |
|
|
from .mm_utils import process_image, process_images |
|
|
from .model_utils_packing import set_seqlens_in_batch |
|
|
from .siglip_encoder import SiglipVisionTower, SiglipVisionTowerDynamicS2, SiglipVisionTowerS2 |
|
|
from .tokenizer_utils import tokenize_conversation |
|
|
from .utils import get_model_config, load_tokenizer_then_handle_media_tokens_and_chat_template |
|
|
|
|
|
from .constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, NUM_EXTRA_TOKENS_VILA, NUM_EXTRA_TOKENS_XVILA |
|
|
from .qwen_audio_encoder import Qwen2AudioTower |
|
|
import whisper |
|
|
|
|
|
from .audio_encoder import AudioTower |
|
|
|
|
|
|
|
|
def build_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel: |
|
|
"""Build multimodal projector from path or configuration.""" |
|
|
if model_type_or_path is None: |
|
|
return None |
|
|
if config.resume_path: |
|
|
assert os.path.exists(model_type_or_path), f"Resume mm projector path {model_type_or_path} does not exist!" |
|
|
return MultimodalProjector.from_pretrained(model_type_or_path, config) |
|
|
else: |
|
|
mm_projector_cfg = MultimodalProjectorConfig(model_type_or_path) |
|
|
mm_projector = MultimodalProjector(mm_projector_cfg, config) |
|
|
return mm_projector |
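
# Illustrative sketch (comment only, not executed): the projector builders above either reload a saved
# projector (when config.resume_path is set and the path exists) or construct a fresh one from a
# projector-type string. Names below are placeholders, not part of this module:
#
#   config.resume_path = None
#   projector = build_mm_projector("mlp_downsample", config)               # fresh projector from a type string
#   config.resume_path = "/path/to/ckpt"                                   # hypothetical checkpoint directory
#   projector = build_mm_projector("/path/to/ckpt/mm_projector", config)   # reload saved projector weights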
|
|
|
|
|
def build_speech_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel: |
|
|
"""Build speech multimodal projector from path or configuration.""" |
|
|
if model_type_or_path is None: |
|
|
return None |
|
|
    if isinstance(config.model_dtype, str):
        model_dtype = eval(config.model_dtype)
    else:
        model_dtype = config.model_dtype
    if config.resume_path:
        assert os.path.exists(model_type_or_path), f"Resume speech mm projector path {model_type_or_path} does not exist!"
        return SpeechMultimodalProjector.from_pretrained(model_type_or_path, config, torch_dtype=model_dtype)
    else:
        speech_mm_projector_cfg = SpeechMultimodalProjectorConfig(model_type_or_path)
        return SpeechMultimodalProjector(speech_mm_projector_cfg, config).to(model_dtype)
|
|
|
|
|
|
|
|
def build_sound_mm_projector(model_type_or_path: str, config: PretrainedConfig) -> PreTrainedModel: |
|
|
"""Build sound multimodal projector from path or configuration.""" |
|
|
if model_type_or_path is None: |
|
|
return None |
|
|
|
|
|
    if isinstance(config.model_dtype, str):
|
|
model_dtype = eval(config.model_dtype) |
|
|
else: |
|
|
model_dtype = config.model_dtype |
|
|
if config.resume_path: |
|
|
assert os.path.exists(model_type_or_path), f"Resume sound mm projector path {model_type_or_path} does not exist!" |
|
|
_model = SoundMultimodalProjector.from_pretrained( |
|
|
model_type_or_path, config, torch_dtype=model_dtype |
|
|
) |
|
|
return _model |
|
|
else: |
|
|
sound_mm_projector_cfg = SoundMultimodalProjectorConfig(model_type_or_path) |
|
|
sound_mm_projector = SoundMultimodalProjector(sound_mm_projector_cfg, config).to(model_dtype) |
|
|
return sound_mm_projector |
|
|
|
|
|
|
|
|
def check_dot_in_model_path(model_path: str): |
|
|
"""Check if the model path contains a dot, which may affect model loading.""" |
|
|
if osp.isdir(model_path): |
|
|
if "." in osp.abspath(model_path): |
|
|
return True |
|
|
else: |
|
|
if "." in model_path: |
|
|
return True |
|
|
return False |
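
# Illustrative example (comment only): a dot anywhere in the (resolved) path triggers the check, e.g.
#
#   check_dot_in_model_path("runs/nvila-8b.v2")   # -> True: the basename contains a dot
#   (for an existing directory, the absolute path is inspected, so a dot in any parent directory also counts)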
|
|
|
|
|
|
|
|
def get_vila_version(model_path: str) -> Optional[str]:
    """Infer the VILA variant from the model path; return None if no known version string is found."""
    VERSIONS = ["vila1.5", "vila-u", "longvila", "nvila", "vila-m3"]
    for version in VERSIONS:
        if version in model_path.lower():
            return version
    return None
|
|
|
|
|
|
|
|
def generate_jinja_template(conv_mode: str) -> str: |
|
|
if conv_mode == "vicuna_v1": |
|
|
return """{% set system_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. " %} |
|
|
{% set roles = ["user", "assistant"] %} |
|
|
{% set sep = " " %} |
|
|
|
|
|
{{ system_prompt }} |
|
|
|
|
|
{% for message in messages %} |
|
|
{% if message['role'] == roles[0] %} |
|
|
{{ "USER: " }}{{ sep }}{{ message['content'] }}{{ sep }} |
|
|
{% else %} |
|
|
{{ "ASSISTANT: " }}{{ sep }}{{ message['content'] }}{{ sep }} |
|
|
{% endif %} |
|
|
{% endfor %} |
|
|
{% if messages[-1]['role'] == 'user' %} |
|
|
{{ "ASSISTANT:" }} |
|
|
{% endif %} |
|
|
""" |
|
|
elif conv_mode == "llama_3": |
|
|
return """{% set system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|>" %} |
|
|
{% set roles = ["<|start_header_id|>user<|end_header_id|>\\n\\n", "<|start_header_id|>assistant<|end_header_id|>\\n\\n"]%} |
|
|
{% set sep = "<|eot_id|>" %} |
|
|
|
|
|
{{ system_prompt }} |
|
|
{% for message in messages %} |
|
|
{% if message['role'] == 'user' %} |
|
|
{{ roles[0] }}{{ message['content'] }}{{ sep }} |
|
|
{% else %} |
|
|
{{ roles[1] }}{{ message['content'] }}{{ sep }} |
|
|
{% endif %} |
|
|
{% endfor %} |
|
|
{% if messages[-1]['role'] == 'user' %} |
|
|
{{ roles[1] }} |
|
|
{% endif %} |
|
|
""" |
|
|
elif conv_mode == "hermes_2": |
|
|
return """{% set system_prompt = "<|im_start|>system\nAnswer the questions." %} |
|
|
{% set roles = ["<|im_start|>user\n", "<|im_start|>assistant\n"] %} |
|
|
{% set sep = "<|im_end|>" %} |
|
|
|
|
|
{{ system_prompt }}{{ sep }} |
|
|
|
|
|
{% for message in messages %} |
|
|
{% if message['role'] == 'user' %} |
|
|
{{ roles[0] }}{{ message['content'] }}{{ sep }} |
|
|
{% else %} |
|
|
{{ roles[1] }}{{ message['content'] }}{{ sep }} |
|
|
{% endif %} |
|
|
{% endfor %}""" |
|
|
else: |
|
|
raise NotImplementedError(f"Jinja template generation is not implemented for {conv_mode}.") |
|
|
|
|
|
|
|
|
def build_vision_tower(model_name_or_path: str, config: PretrainedConfig) -> PreTrainedModel: |
|
|
"""Build vision tower from path or configuration.""" |
|
|
|
|
|
if model_name_or_path is None: |
|
|
return None |
|
|
|
|
|
vision_tower_arch = None |
|
|
if config.resume_path and "radio" not in model_name_or_path: |
|
|
assert os.path.exists(model_name_or_path), f"Resume vision tower path {model_name_or_path} does not exist!" |
|
|
vision_tower_cfg = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) |
|
|
vision_tower_arch = vision_tower_cfg.architectures[0].lower() |
|
|
vision_tower_name = vision_tower_arch if vision_tower_arch is not None else model_name_or_path |
|
|
|
|
|
use_s2 = getattr(config, "s2", False) |
|
|
use_dynamic_s2 = getattr(config, "dynamic_s2", False) |
|
|
|
|
|
if "siglip" in vision_tower_name: |
|
|
if use_dynamic_s2: |
|
|
vision_tower = SiglipVisionTowerDynamicS2(model_name_or_path, config) |
|
|
elif use_s2: |
|
|
vision_tower = SiglipVisionTowerS2(model_name_or_path, config) |
|
|
else: |
|
|
vision_tower = SiglipVisionTower(model_name_or_path, config) |
|
|
else: |
|
|
raise NotImplementedError(f"Unknown vision tower: {model_name_or_path}") |
|
|
|
|
|
config.mm_hidden_size = ( |
|
|
vision_tower.config.hidden_size if not (use_s2 or use_dynamic_s2) else vision_tower.hidden_size |
|
|
) |
|
|
return vision_tower |
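
# Illustrative sketch (comment only): build_vision_tower records the vision feature width on the shared
# config as config.mm_hidden_size (the S2 variants expose a larger multi-scale width). Example, with a
# placeholder SigLIP checkpoint name:
#
#   vision_tower = build_vision_tower("google/siglip-so400m-patch14-384", config)
#   print(config.mm_hidden_size)   # tower hidden size, or the concatenated width when s2/dynamic_s2 is on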
|
|
|
|
|
|
|
|
def build_audio_tower(model_name_or_path: str, config: PretrainedConfig, encoder_type: str) -> PreTrainedModel: |
|
|
"""Build audio tower for sound or speech processing.""" |
|
|
assert encoder_type in ["sound", "speech"] |
|
|
|
|
|
|
|
|
if model_name_or_path is None: |
|
|
return None |
|
|
|
|
|
model_type = "af3" |
|
|
|
|
|
if model_type == "af3": |
|
|
model = Qwen2AudioTower(model_name_or_path, config) |
|
|
output_dim = 1280 |
|
|
else: |
|
|
raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}") |
|
|
|
|
|
if encoder_type == "sound": |
|
|
config.sound_hidden_size = output_dim |
|
|
elif encoder_type == "speech": |
|
|
config.speech_hidden_size = output_dim |
|
|
else: |
|
|
raise NotImplementedError(f"Not implemented for this encoder: {model_name_or_path}") |
|
|
|
|
|
return model |
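
# Illustrative sketch (comment only): build_audio_tower also records the encoder output width on the
# shared config so that the matching projector can be sized. Example (names are placeholders):
#
#   sound_tower = build_audio_tower("/path/to/audio_encoder", config, encoder_type="sound")
#   assert config.sound_hidden_size == 1280   # set above for the "af3" (Qwen2 audio) encoder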
|
|
|
|
|
|
|
|
class VILAPretrainedModel(PreTrainedModel): |
|
|
config_class = VILAConfig |
|
|
main_input_name = "input_embeds" |
|
|
supports_gradient_checkpointing = True |
|
|
_supports_flash_attn_2 = True |
|
|
_no_split_modules = ["Qwen2DecoderLayer", "SiglipEncoderLayer"] |
|
|
|
|
|
def __init__(self, config: VILAConfig, *args, **kwargs): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
cfgs = get_model_config(config) |
|
|
|
|
|
if len(cfgs) == 7: |
|
|
( |
|
|
llm_cfg, |
|
|
vision_tower_cfg, |
|
|
speech_tower_cfg, |
|
|
sound_tower_cfg, |
|
|
mm_projector_cfg, |
|
|
speech_mm_projector_cfg, |
|
|
sound_mm_projector_cfg, |
|
|
) = cfgs |
|
|
else: |
|
|
raise ValueError( |
|
|
"`llm_cfg` `mm_projector_cfg` `speech_mm_projector_cfg` `sound_mm_projector_cfg` `vision_tower_cfg` `speech_tower_cfg` `sound_tower_cfg` not found in the config." |
|
|
) |
|
|
|
|
|
|
|
|
device_map = kwargs.get("device_map", "auto") |
|
|
self.mm_projector = build_mm_projector(mm_projector_cfg, config) |
|
|
self.vision_tower = build_vision_tower(vision_tower_cfg, config) |
|
|
|
|
|
if speech_tower_cfg: |
|
|
self.speech_tower = build_audio_tower(speech_tower_cfg, config, encoder_type="speech") |
|
|
self.speech_mm_projector = build_speech_mm_projector(speech_mm_projector_cfg, config) |
|
|
if sound_tower_cfg: |
|
|
self.sound_tower = build_audio_tower(sound_tower_cfg, config, encoder_type="sound") |
|
|
self.sound_mm_projector = build_sound_mm_projector(sound_mm_projector_cfg, config) |
|
|
|
|
|
|
|
|
if device_map in ["auto", "cuda"]: |
|
|
self.mm_projector = self.mm_projector.cuda() |
|
|
self.vision_tower = self.vision_tower.cuda() |
|
|
self.speech_tower = self.speech_tower.cuda() if hasattr(self, "speech_tower") else None |
|
|
self.sound_tower = self.sound_tower.cuda() if hasattr(self, "sound_tower") else None |
|
|
self.speech_mm_projector = self.speech_mm_projector.cuda() if hasattr(self, "speech_mm_projector") else None |
|
|
self.sound_mm_projector = self.sound_mm_projector.cuda() if hasattr(self, "sound_mm_projector") else None |
|
|
|
|
|
self.llm, self.tokenizer = self.init_llm(llm_cfg, config, device_map=device_map) |
|
|
|
|
|
self.llm_model_embed_tokens = self.llm.model.embed_tokens |
|
|
|
|
|
self.tokenizer.padding_side = "left" |
|
|
|
|
|
self.vocab_size = len(self.tokenizer) |
|
|
self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer)) |
|
|
|
|
|
self.encoders = {} |
|
|
for name in ["image", "video", "speech", "sound"]: |
|
|
encoder_config = getattr(self.config, f"{name}_encoder") |
|
|
if isinstance(encoder_config, str): |
|
|
encoder_config = json.loads(encoder_config) |
|
|
if encoder_config.get("embed_time", False) == "True": |
|
|
if "trope_dim" not in encoder_config and encoder_config.get("time_embed_type", "") in ["pixel", "lang"]: |
|
|
encoder_config["trope_dim"] = self.config.hidden_size // 2 |
|
|
print(f"Warning: trope_dim not found in config, defaulting to hidden_size // 2: {encoder_config['trope_dim']}") |
|
|
|
|
|
encoder_config.pop('_target_') |
|
|
if name == "video": |
|
|
self.encoders[name] = TSPVideoEncoder(parent=self, **encoder_config) |
|
|
elif name == "image": |
|
|
self.encoders[name] = BasicImageEncoder(self) |
|
|
else: |
|
|
self.encoders[name] = BasicSoundEncoder(parent=self, **encoder_config) |
|
|
|
|
|
self.post_config() |
|
|
self.is_loaded = True |
|
|
|
|
|
self.llm_only_need_embed = kwargs.get("llm_only_need_embed", False) |
|
|
if self.llm_only_need_embed: |
|
|
print("We only need the embed_tokens in llm.") |
|
|
del self.llm |
|
|
self.llm = None |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
assert ( |
|
|
self.llm is not None |
|
|
or self.vision_tower is not None |
|
|
or self.speech_tower is not None |
|
|
or self.mm_projector is not None |
|
|
or self.speech_mm_projector is not None |
|
|
), "At least one of the components must be instantiated." |
|
|
|
|
|
|
|
|
@classmethod |
|
|
    def copy_or_symlink_directory(cls, model_path, output_dir, copy=True):
        """Copy (or symlink, when copy=False) every item under model_path into output_dir, replacing any existing entries."""
|
|
|
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
for item in os.listdir(model_path): |
|
|
src_path = os.path.join(model_path, item) |
|
|
dst_path = os.path.join(output_dir, item) |
|
|
|
|
|
|
|
|
if os.path.exists(dst_path): |
|
|
if os.path.islink(dst_path): |
|
|
os.unlink(dst_path) |
|
|
elif os.path.isdir(dst_path): |
|
|
shutil.rmtree(dst_path) |
|
|
else: |
|
|
os.remove(dst_path) |
|
|
|
|
|
|
|
|
if copy: |
|
|
if os.path.isdir(src_path): |
|
|
shutil.copytree(src_path, dst_path) |
|
|
else: |
|
|
shutil.copy2(src_path, dst_path) |
|
|
print(f"Copied {src_path} to {dst_path}") |
|
|
else: |
|
|
os.symlink(src_path, dst_path) |
|
|
print(f"Created symlink from {src_path} to {dst_path}") |
|
|
|
|
|
@classmethod |
|
|
    def copy_remote_py_files(cls, output_dir, copy=True):
        """Copy (or symlink) this module's remote-code .py/.jinja files into output_dir and prepend INSTRUCTIONS.md to the exported README.md."""
|
|
|
|
|
current_file_path = os.path.abspath(__file__) |
|
|
current_folder = os.path.dirname(current_file_path) |
|
|
for file_name in os.listdir(current_folder): |
|
|
if file_name == "INSTRUCTIONS.md": |
|
|
src_fname = os.path.join(current_folder, file_name) |
|
|
dst_fname = os.path.join(output_dir, "README.md") |
|
|
if os.path.exists(dst_fname): |
|
|
old_readme = open(dst_fname).read() |
|
|
else: |
|
|
old_readme = "" |
|
|
with open(src_fname) as src, open(dst_fname, "w") as dst: |
|
|
dst.write(src.read()) |
|
|
dst.write(old_readme) |
|
|
print("[HF] README", src_fname, "to", dst_fname) |
|
|
if file_name.endswith(".py") or file_name.endswith(".jinja"): |
|
|
full_file_name = os.path.join(current_folder, file_name) |
|
|
if os.path.isfile(full_file_name): |
|
|
if copy: |
|
|
shutil.copy(full_file_name, output_dir) |
|
|
print("[HF] copying", full_file_name, "to", output_dir) |
|
|
else: |
|
|
|
|
|
if os.path.exists(os.path.join(output_dir, file_name)): |
|
|
os.remove(os.path.join(output_dir, file_name)) |
|
|
os.symlink(full_file_name, os.path.join(output_dir, file_name)) |
|
|
print("[HF] linking", full_file_name, "to", output_dir) |
|
|
|
|
|
def save_pretrained(self, output_dir, state_dict=None, **kwargs): |
|
|
if state_dict is None: |
|
|
|
|
|
|
|
|
state_dict = self.state_dict() |
|
|
|
|
|
if getattr(self, "tokenizer", None): |
|
|
self.tokenizer.save_pretrained(osp.join(output_dir, "llm")) |
|
|
|
|
|
if self.get_llm(): |
|
|
print(f"saving llm to {osp.join(output_dir, 'llm')}") |
|
|
self.llm.config._name_or_path = osp.join(output_dir, "llm") |
|
|
llm_state_dict = OrderedDict({k.split("llm.")[-1]: v for k, v in state_dict.items() if "llm" in k}) |
|
|
self.llm.save_pretrained(os.path.join(output_dir, "llm"), state_dict=llm_state_dict) |
|
|
self.config.llm_cfg = self.llm.config |
|
|
|
|
|
if self.get_vision_tower(): |
|
|
print(f"saving vision_tower to {osp.join(output_dir, 'vision_tower')}") |
|
|
self.vision_tower.config._name_or_path = osp.join(output_dir, "vision_tower") |
|
|
vision_tower_state_dict = OrderedDict( |
|
|
{k.split("vision_tower.vision_tower.")[-1]: v for k, v in state_dict.items() if "vision_tower" in k} |
|
|
) |
|
|
self.vision_tower.vision_tower.save_pretrained( |
|
|
os.path.join(output_dir, "vision_tower"), |
|
|
state_dict=vision_tower_state_dict, |
|
|
) |
|
|
self.vision_tower.image_processor.save_pretrained(os.path.join(output_dir, "vision_tower")) |
|
|
self.config.vision_tower_cfg = self.vision_tower.config |
|
|
if hasattr(self.config.vision_tower_cfg, "auto_map"): |
|
|
if "radio" not in self.get_vision_tower().__class__.__name__.lower(): |
|
|
delattr(self.config.vision_tower_cfg, "auto_map") |
|
|
if self.get_speech_tower(): |
|
|
print(f"saving speech_tower to {osp.join(output_dir, 'speech_tower')}") |
|
|
self.speech_tower.config._name_or_path = osp.join(output_dir, "speech_tower").replace( |
|
|
"tmp-checkpoint", "checkpoint" |
|
|
) |
|
|
|
|
|
speech_tower_state_dict = OrderedDict( |
|
|
{k.split("speech_tower.audio_tower.")[-1]: v for k, v in state_dict.items() if "speech_tower" in k} |
|
|
) |
|
|
|
|
|
self.speech_tower.audio_tower.save_pretrained( |
|
|
os.path.join(output_dir, "speech_tower"), |
|
|
state_dict=speech_tower_state_dict, |
|
|
) |
|
|
self.config.speech_tower_cfg = self.speech_tower.config |
|
|
|
|
|
if self.get_sound_tower(): |
|
|
print(f"saving sound_tower to {osp.join(output_dir, 'sound_tower')}") |
|
|
self.sound_tower.config._name_or_path = osp.join(output_dir, "sound_tower").replace( |
|
|
"tmp-checkpoint", "checkpoint" |
|
|
) |
|
|
|
|
|
sound_tower_state_dict = OrderedDict( |
|
|
{k.split("sound_tower.audio_tower.")[-1]: v for k, v in state_dict.items() if "sound_tower" in k} |
|
|
) |
|
|
|
|
|
self.sound_tower.audio_tower.save_pretrained( |
|
|
os.path.join(output_dir, "sound_tower"), |
|
|
state_dict=sound_tower_state_dict, |
|
|
) |
|
|
self.config.sound_tower_cfg = self.sound_tower.config |
|
|
|
|
|
|
|
|
|
|
|
if self.get_mm_projector(): |
|
|
print(f"saving mm_projector to {osp.join(output_dir, 'mm_projector')}") |
|
|
self.mm_projector.config._name_or_path = osp.join(output_dir, "mm_projector") |
|
|
mm_projector_state_dict = OrderedDict( |
|
|
{k.split("mm_projector.")[-1]: v for k, v in state_dict.items() if "mm_projector" in k} |
|
|
) |
|
|
self.mm_projector.save_pretrained( |
|
|
os.path.join(output_dir, "mm_projector"), |
|
|
state_dict=mm_projector_state_dict, |
|
|
) |
|
|
self.config.mm_projector_cfg = self.mm_projector.config |
|
|
|
|
|
if self.get_speech_mm_projector(): |
|
|
print(f"saving speech_mm_projector to {osp.join(output_dir, 'speech_mm_projector')}") |
|
|
self.speech_mm_projector.config._name_or_path = osp.join(output_dir, "speech_mm_projector").replace( |
|
|
"tmp-checkpoint", "checkpoint" |
|
|
) |
|
|
speech_mm_projector_state_dict = OrderedDict( |
|
|
{k.split("speech_mm_projector.")[-1]: v for k, v in state_dict.items() if "speech_mm_projector" in k} |
|
|
) |
|
|
self.speech_mm_projector.save_pretrained( |
|
|
os.path.join(output_dir, "speech_mm_projector"), |
|
|
state_dict=speech_mm_projector_state_dict, |
|
|
) |
|
|
self.config.speech_mm_projector_cfg = self.speech_mm_projector.config |
|
|
|
|
|
if self.get_sound_mm_projector(): |
|
|
print(f"saving sound_mm_projector to {osp.join(output_dir, 'sound_mm_projector')}") |
|
|
self.sound_mm_projector.config._name_or_path = osp.join(output_dir, "sound_mm_projector").replace( |
|
|
"tmp-checkpoint", "checkpoint" |
|
|
) |
|
|
|
|
|
sound_mm_projector_state_dict = OrderedDict( |
|
|
{k.split("sound_mm_projector.")[-1]: v for k, v in state_dict.items() if "sound_mm_projector" in k} |
|
|
) |
|
|
self.sound_mm_projector.save_pretrained( |
|
|
os.path.join(output_dir, "sound_mm_projector"), |
|
|
state_dict=sound_mm_projector_state_dict, |
|
|
) |
|
|
self.config.sound_mm_projector_cfg = self.sound_mm_projector.config |
|
|
|
|
|
|
|
|
self.config._name_or_path = output_dir |
|
|
self.config.architectures = [self.__class__.__name__] |
|
|
self.config.save_pretrained(output_dir) |
|
|
|
|
|
|
|
|
self.copy_remote_py_files(output_dir) |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained( |
|
|
cls, |
|
|
pretrained_model_name_or_path: Optional[str] = None, |
|
|
*model_args, |
|
|
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, |
|
|
cache_dir: Optional[Union[str, os.PathLike]] = None, |
|
|
ignore_mismatched_sizes: bool = False, |
|
|
force_download: bool = False, |
|
|
local_files_only: bool = False, |
|
|
token: Optional[Union[str, bool]] = None, |
|
|
revision: str = "main", |
|
|
use_safetensors: Optional[bool] = None, |
|
|
weights_only: bool = True, |
|
|
**kwargs, |
|
|
): |
|
|
|
|
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) |
|
|
if kwargs.get("torch_dtype", None) is not None: |
|
|
config.torch_dtype = kwargs.get("torch_dtype", None) |
|
|
config.model_dtype = kwargs.get("torch_dtype", None) |
|
|
if type(kwargs.get("torch_dtype", None)) == str: |
|
|
kwargs["torch_dtype"] = eval(kwargs.get("torch_dtype", None)) |
|
|
else: |
|
|
kwargs["torch_dtype"] = kwargs.get("torch_dtype", None) |
|
|
return cls._from_config(config, **kwargs) |
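
    # Illustrative sketch (comment only): typical loading path for this remote-code model. The repo id is
    # a placeholder; torch_dtype may be passed either as a string or as a torch dtype object.
    #
    #   model = VILAForCausalLM.from_pretrained("org/vila-checkpoint", torch_dtype="torch.float16")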
|
|
|
|
|
def init_llm(self, llm_config, config, *args, **kwargs): |
|
|
"""Initialize language model and tokenizer.""" |
|
|
self.llm, self.tokenizer = build_llm_and_tokenizer(llm_config, config, *args, **kwargs) |
|
|
|
|
|
self.pad_token_list = ( |
|
|
self.tokenizer.pad_token_id, |
|
|
self.tokenizer.eos_token_id, |
|
|
self.tokenizer.tokenize("<|endoftext|>")[0], |
|
|
) |
|
|
|
|
|
self.vocab_size = len(self.tokenizer) |
|
|
self.update_vocab_size = lambda: setattr(self, "vocab_size", len(self.tokenizer)) |
|
|
|
|
|
|
|
|
self.grammar_compiler = None |
|
|
|
|
|
return self.llm, self.tokenizer |
|
|
|
|
|
def post_config(self): |
|
|
self.training = self.llm.training |
|
|
if self.training: |
|
|
self.train() |
|
|
else: |
|
|
self.eval() |
|
|
|
|
|
|
|
|
if getattr(self.config, "llm_cfg", None) is None: |
|
|
self.config.llm_cfg = self.llm.config |
|
|
if getattr(self.config, "vision_tower_cfg", None) is None: |
|
|
self.config.vision_tower_cfg = self.vision_tower.config |
|
|
if getattr(self.config, "mm_projector_cfg", None) is None: |
|
|
self.config.mm_projector_cfg = self.mm_projector.config |
|
|
if getattr(self.config, "speech_tower_cfg", None) is None and hasattr(self, "speech_tower"): |
|
|
self.config.speech_tower_cfg = self.speech_tower.config |
|
|
if getattr(self.config, "sound_tower_cfg", None) is None and hasattr(self, "sound_tower"): |
|
|
self.config.sound_tower_cfg = self.sound_tower.config |
|
|
if getattr(self.config, "speech_mm_projector_cfg", None) is None and hasattr(self, "speech_mm_projector"): |
|
|
self.config.speech_mm_projector_cfg = self.speech_mm_projector.config |
|
|
if getattr(self.config, "sound_mm_projector_cfg", None) is None and hasattr(self, "sound_mm_projector"): |
|
|
self.config.sound_mm_projector_cfg = self.sound_mm_projector.config |
|
|
|
|
|
def get_llm(self): |
|
|
llm = getattr(self, "llm", None) |
|
|
if type(llm) is list: |
|
|
llm = llm[0] |
|
|
return llm |
|
|
|
|
|
def get_lm_head(self): |
|
|
lm_head = getattr(self.get_llm(), "lm_head", None) |
|
|
return lm_head |
|
|
|
|
|
def get_vision_tower(self): |
|
|
vision_tower = getattr(self, "vision_tower", None) |
|
|
if type(vision_tower) is list: |
|
|
vision_tower = vision_tower[0] |
|
|
return vision_tower |
|
|
|
|
|
def get_speech_tower(self): |
|
|
speech_tower = getattr(self, "speech_tower", None) |
|
|
if type(speech_tower) is list: |
|
|
speech_tower = speech_tower[0] |
|
|
return speech_tower |
|
|
|
|
|
def get_sound_tower(self): |
|
|
sound_tower = getattr(self, "sound_tower", None) |
|
|
if type(sound_tower) is list: |
|
|
sound_tower = sound_tower[0] |
|
|
return sound_tower |
|
|
|
|
|
def get_mm_projector(self): |
|
|
mm_projector = getattr(self, "mm_projector", None) |
|
|
if type(mm_projector) is list: |
|
|
mm_projector = mm_projector[0] |
|
|
return mm_projector |
|
|
|
|
|
def get_sound_mm_projector(self): |
|
|
sound_mm_projector = getattr(self, "sound_mm_projector", None) |
|
|
if type(sound_mm_projector) is list: |
|
|
sound_mm_projector = sound_mm_projector[0] |
|
|
return sound_mm_projector |
|
|
|
|
|
|
|
|
|
|
def get_speech_mm_projector(self): |
|
|
speech_mm_projector = getattr(self, "speech_mm_projector", None) |
|
|
if type(speech_mm_projector) is list: |
|
|
speech_mm_projector = speech_mm_projector[0] |
|
|
return speech_mm_projector |
|
|
|
|
|
def freezed_module_patch(self): |
|
|
""" |
|
|
Huggingface will call model.train() at each training_step. To ensure the expected behaviors for modules like dropout, batchnorm, etc., we need to call model.eval() for the freezed modules. |
|
|
""" |
|
|
if self.training: |
|
|
if self.get_llm() and not getattr(self.config, "tune_language_model", False): |
|
|
pass |
|
|
if self.get_vision_tower() and not getattr(self.config, "tune_vision_tower", False): |
|
|
self.get_vision_tower().eval() |
|
|
if self.get_speech_tower() and not getattr(self.config, "tune_speech_tower", False): |
|
|
self.get_speech_tower().eval() |
|
|
if self.get_sound_tower() and not getattr(self.config, "tune_sound_tower", False): |
|
|
self.get_sound_tower().eval() |
|
|
if self.get_mm_projector() and not getattr(self.config, "tune_mm_projector", False): |
|
|
self.get_mm_projector().eval() |
|
|
if self.get_speech_mm_projector() and not getattr(self.config, "tune_speech_mm_projector", False): |
|
|
self.get_speech_mm_projector().eval() |
|
|
if self.get_sound_mm_projector() and not getattr(self.config, "tune_sound_mm_projector", False): |
|
|
self.get_sound_mm_projector().eval() |
|
|
|
|
|
|
|
|
class VILAForCausalLM(VILAPretrainedModel): |
|
|
def __init__(self, config: VILAConfig, *args, **kwargs): |
|
|
super().__init__(config, *args, **kwargs) |
|
|
|
|
|
def merge_features_for_dynamic_s2(self, image_features, block_sizes): |
|
|
scales = self.get_vision_tower().scales |
|
|
resize_output_to_scale_idx = self.get_vision_tower().resize_output_to_scale_idx |
|
|
|
|
|
image_features_each_image = [] |
|
|
new_block_sizes = [] |
|
|
block_cnt = 0 |
|
|
for block_size_each_image in block_sizes: |
|
|
if block_size_each_image is None: |
|
|
cur_features = image_features[block_cnt : block_cnt + 1] |
|
|
cur_features = rearrange(cur_features, "1 (h w) c -> 1 c h w", h=int(cur_features.shape[1] ** 0.5)) |
|
|
cur_features = cur_features.repeat(1, len(scales), 1, 1) |
|
|
image_features_each_image.append(cur_features) |
|
|
new_block_sizes.append((1, 1)) |
|
|
block_cnt += 1 |
|
|
else: |
|
|
cur_features_each_scale = [] |
|
|
for scale in scales[:-1]: |
|
|
num_blocks_this_scale = (scale // scales[0]) ** 2 |
|
|
cur_features_each_scale.append( |
|
|
self.merge_chessboard( |
|
|
image_features[block_cnt : block_cnt + num_blocks_this_scale], |
|
|
num_split_h=scale // scales[0], |
|
|
num_split_w=scale // scales[0], |
|
|
) |
|
|
) |
|
|
block_cnt += num_blocks_this_scale |
|
|
num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1] |
|
|
cur_features_each_scale.append( |
|
|
self.merge_chessboard( |
|
|
image_features[block_cnt : block_cnt + num_blocks_last_scale], |
|
|
num_split_h=block_size_each_image[0], |
|
|
num_split_w=block_size_each_image[1], |
|
|
) |
|
|
) |
|
|
block_cnt += num_blocks_last_scale |
|
|
|
|
|
|
|
|
output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:] |
|
|
cur_features = torch.cat( |
|
|
[ |
|
|
F.interpolate(cur_features_each_scale[i].to(torch.float32), size=output_size, mode="area").to( |
|
|
cur_features_each_scale[i].dtype |
|
|
) |
|
|
for i in range(len(cur_features_each_scale)) |
|
|
], |
|
|
dim=1, |
|
|
) |
|
|
|
|
|
image_features_each_image.append(cur_features) |
|
|
|
|
|
if resize_output_to_scale_idx == len(scales) - 1 or resize_output_to_scale_idx == -1: |
|
|
new_block_sizes.append(block_size_each_image) |
|
|
else: |
|
|
new_block_sizes.append( |
|
|
( |
|
|
scales[resize_output_to_scale_idx] // scales[0], |
|
|
scales[resize_output_to_scale_idx] // scales[0], |
|
|
) |
|
|
) |
|
|
|
|
|
assert block_cnt == len(image_features) |
|
|
|
|
|
return image_features_each_image, new_block_sizes |
|
|
|
|
|
@staticmethod |
|
|
def split_chessboard(x, num_split_h, num_split_w): |
|
|
""" |
|
|
x: b * c * h * w |
|
|
out: b * c * h * w |
|
|
Deividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension |
|
|
""" |
|
|
B, C, H, W = x.shape |
|
|
assert H % num_split_h == 0 and W % num_split_w == 0 |
|
|
h, w = H // num_split_h, W // num_split_w |
|
|
x_split = torch.cat( |
|
|
[x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w] for i in range(num_split_h) for j in range(num_split_w)], |
|
|
dim=0, |
|
|
) |
|
|
return x_split |
|
|
|
|
|
@staticmethod |
|
|
def merge_chessboard(x, num_split_h, num_split_w): |
|
|
""" |
|
|
x: b * n * c or b * h * w * c |
|
|
out: b * c * h * w |
|
|
Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square. |
|
|
""" |
|
|
B = x.shape[0] |
|
|
if x.dim() == 3: |
|
|
N = x.shape[1] |
|
|
x = rearrange(x, "b (h w) c -> b c h w", h=int(N**0.5), w=int(N**0.5)) |
|
|
|
|
|
assert B % (num_split_h * num_split_w) == 0 |
|
|
b = B // (num_split_h * num_split_w) |
|
|
|
|
|
x_merge = torch.cat( |
|
|
[ |
|
|
torch.cat( |
|
|
[x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b] for j in range(num_split_w)], dim=-1 |
|
|
) |
|
|
for i in range(num_split_h) |
|
|
], |
|
|
dim=-2, |
|
|
) |
|
|
|
|
|
return x_merge |
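
    # Illustrative sketch (comment only): split_chessboard and merge_chessboard are inverses up to layout.
    # For example (shapes only, values are placeholders):
    #
    #   x = torch.randn(2, 1024, 3)                                                   # 2 tiles of 32x32 tokens, channel-last
    #   grid = VILAForCausalLM.merge_chessboard(x, num_split_h=1, num_split_w=2)      # -> (1, 3, 32, 64)
    #   tiles = VILAForCausalLM.split_chessboard(grid, num_split_h=1, num_split_w=2)  # -> (2, 3, 32, 32)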
|
|
|
|
|
    def encode_video(self, inp, block_sizes: Optional[Tuple[int, ...]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
|
|
bs = len(inp) |
|
|
cache_feas = [] |
|
|
cache_feas_index = [] |
|
|
inp_block_sizes = block_sizes |
|
|
|
|
|
|
|
|
        for _idx in range(len(inp)):
            if isinstance(inp[_idx], CacheFeatures):
                cache_feas.append(inp[_idx])
                cache_feas_index.append(_idx)
        raw_images = [_ for _ in inp if not isinstance(_, CacheFeatures)]
|
|
|
|
|
raw_videos_num_frames = [_.shape[0] for _ in raw_images] |
|
|
if len(raw_images) > 0: |
|
|
images = torch.cat(raw_images, dim=0) |
|
|
else: |
|
|
images = [] |
|
|
|
|
|
if block_sizes is None: |
|
|
block_sizes = [None] * len(images) |
|
|
|
|
|
def _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames): |
|
|
|
|
|
if len(cache_feas) > 0: |
|
|
if len(image_features) > 0: |
|
|
image_features = torch.split(image_features, raw_videos_num_frames) |
|
|
new_image_features = [] |
|
|
cache_feas_idx = 0 |
|
|
raw_fea_idx = 0 |
|
|
for _idx in range(bs): |
|
|
if _idx in cache_feas_index: |
|
|
new_image_features.append(cache_feas[cache_feas_idx].value['features'].to(self.device, self.dtype)) |
|
|
cache_feas_idx += 1 |
|
|
else: |
|
|
new_image_features.append(image_features[raw_fea_idx]) |
|
|
raw_fea_idx += 1 |
|
|
|
|
|
assert len(new_image_features) == bs |
|
|
image_features = new_image_features |
|
|
image_features = torch.cat(image_features, dim=0) |
|
|
return image_features |
|
|
|
|
|
if getattr(self.config, "dynamic_s2", False): |
|
|
|
|
|
if len(images) > 0: |
|
|
image_features = self.get_vision_tower()(images) |
|
|
|
|
|
image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes) |
|
|
|
|
|
image_features = [ |
|
|
self.split_chessboard(x, block_size[0], block_size[1]) |
|
|
for x, block_size in zip(image_features, new_block_sizes) |
|
|
] |
|
|
image_features = torch.cat( |
|
|
[rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0 |
|
|
) |
|
|
else: |
|
|
image_features = [] |
|
|
|
|
|
|
|
|
image_features = _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if inp_block_sizes is None: |
|
|
new_block_sizes = [(1, 1)] * len(image_features) |
|
|
else: |
|
|
raise ValueError(f"inp_block_sizes is not None: {inp_block_sizes}") |
|
|
image_features = image_features.to(self.device, self.dtype) |
|
|
image_features = self.get_mm_projector()(image_features) |
|
|
image_features = list( |
|
|
image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0) |
|
|
) |
|
|
image_features = [ |
|
|
self.merge_chessboard(x, block_size[0], block_size[1]) |
|
|
for x, block_size in zip(image_features, new_block_sizes) |
|
|
] |
|
|
image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] |
|
|
if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]): |
|
|
image_features = torch.stack(image_features, dim=0) |
|
|
else: |
|
|
if len(images) > 0: |
|
|
image_features = self.get_vision_tower()(images) |
|
|
else: |
|
|
image_features = [] |
|
|
|
|
|
|
|
|
image_features = _load_video_features(image_features, cache_feas, cache_feas_index, raw_videos_num_frames) |
|
|
|
|
|
image_features = self.get_mm_projector()(image_features) |
|
|
return image_features |
|
|
|
|
|
    def encode_images(self, images, block_sizes: Optional[Tuple[int, ...]] = None, mm_info: Optional[dict] = None, num_frames: Optional[List[int]] = None):
|
|
if block_sizes is None: |
|
|
block_sizes = [None] * len(images) |
|
|
|
|
|
if getattr(self.config, "dynamic_s2", False): |
|
|
image_features = self.get_vision_tower()(images) |
|
|
|
|
|
image_features, new_block_sizes = self.merge_features_for_dynamic_s2(image_features, block_sizes) |
|
|
|
|
|
image_features = [ |
|
|
self.split_chessboard(x, block_size[0], block_size[1]) |
|
|
for x, block_size in zip(image_features, new_block_sizes) |
|
|
] |
|
|
image_features = torch.cat( |
|
|
[rearrange(x, "b c h w -> b (h w) c") for x in image_features], dim=0 |
|
|
) |
|
|
|
|
|
image_features = self.get_mm_projector()(image_features) |
|
|
image_features = list( |
|
|
image_features.split([block_size[0] * block_size[1] for block_size in new_block_sizes], dim=0) |
|
|
) |
|
|
image_features = [ |
|
|
self.merge_chessboard(x, block_size[0], block_size[1]) |
|
|
for x, block_size in zip(image_features, new_block_sizes) |
|
|
] |
|
|
image_features = [rearrange(x, "1 c h w -> (h w) c") for x in image_features] |
|
|
if all([feature.shape[0] == image_features[0].shape[0] for feature in image_features]): |
|
|
image_features = torch.stack(image_features, dim=0) |
|
|
else: |
|
|
image_features = self.get_vision_tower()(images) |
|
|
|
|
|
image_features = self.get_mm_projector()(image_features) |
|
|
return image_features |
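
    # Illustrative sketch (comment only): without dynamic S2, encode_images is a plain tower + projector
    # pass. Shapes below are placeholders for an example tower resolution:
    #
    #   images = torch.randn(4, 3, 448, 448).to(model.device, model.dtype)
    #   feats = model.encode_images(images)   # -> (4, num_projected_tokens, llm_hidden_size)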
|
|
|
|
|
def encode_sound(self, sounds, mm_info: Optional[dict] = None): |
|
|
|
|
|
audio_features, audio_output_lengths = self.get_sound_tower()(sounds) |
|
|
|
|
|
use_fea_downsample = False |
|
|
if getattr(self.config, "sound_mm_projector", "") != "": |
|
|
if "mlp_downsample" in getattr(self.config, "sound_mm_projector", ""): |
|
|
use_fea_downsample = True |
|
|
else: |
|
|
sound_mm_projector_cfg = getattr(self.config, "sound_mm_projector_cfg", None) |
|
|
if sound_mm_projector_cfg is not None: |
|
|
                if isinstance(sound_mm_projector_cfg, dict):
                    if "mlp_downsample" in sound_mm_projector_cfg["sound_mm_projector_type"]:
                        use_fea_downsample = True
                elif isinstance(sound_mm_projector_cfg, SoundMultimodalProjectorConfig):
                    if "mlp_downsample" in sound_mm_projector_cfg.sound_mm_projector_type:
                        use_fea_downsample = True
|
|
|
|
|
if not use_fea_downsample: |
|
|
audio_features = self.get_sound_mm_projector()(audio_features) |
|
|
|
|
|
if audio_output_lengths is not None: |
|
|
|
|
|
new_audio_features = [] |
|
|
start = 0 |
|
|
for length in audio_output_lengths: |
|
|
new_audio_features.append(audio_features[start : start + length]) |
|
|
start += length |
|
|
audio_features = new_audio_features |
|
|
|
|
|
if use_fea_downsample: |
|
|
audio_features = torch.stack(audio_features, dim=0) |
|
|
audio_features = self.get_sound_mm_projector()(audio_features) |
|
|
|
|
|
return audio_features |
|
|
|
|
|
def train(self, mode: bool = True): |
|
|
super().train(mode) |
|
|
return self |
|
|
|
|
|
def _embed( |
|
|
self, |
|
|
input_ids: torch.Tensor, |
|
|
media: Dict[str, List[torch.Tensor]], |
|
|
media_config: Dict[str, Dict[str, Any]], |
|
|
labels: Optional[torch.Tensor], |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
media = copy.deepcopy(media) |
|
|
media_config = copy.deepcopy(media_config) |
|
|
|
|
|
labels = labels if labels is not None else torch.full_like(input_ids, IGNORE_INDEX) |
|
|
attention_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids, dtype=torch.bool) |
|
|
|
|
|
PROCESS_GROUP_MANAGER = None |
|
|
if PROCESS_GROUP_MANAGER is not None: |
|
|
for name in media: |
|
|
self.encoders[name].end_tokens = None |
|
|
|
|
|
|
|
|
text_embeds = self.llm_model_embed_tokens(input_ids) |
|
|
|
|
|
mm_info = {} |
|
|
if "video_info" in media: |
|
|
video_info = media["video_info"] |
|
|
del media["video_info"] |
|
|
mm_info['video_info'] = video_info |
|
|
else: |
|
|
video_info = None |
|
|
|
|
|
if "audio_info" in media: |
|
|
audio_info = media["audio_info"] |
|
|
del media["audio_info"] |
|
|
mm_info['audio_info'] = audio_info |
|
|
else: |
|
|
audio_info = None |
|
|
|
|
|
if media is not None: |
|
|
media_embeds = self.__embed_media_tokens(media, media_config, mm_info) |
|
|
else: |
|
|
|
|
|
media_embeds = {} |
|
|
|
|
|
if PROCESS_GROUP_MANAGER is not None: |
|
|
media_embeds_video = [] |
|
|
for i, images in enumerate(media_embeds["video"]): |
|
|
num_video_frame = media["video"][i].shape[0] |
|
|
media_embeds_video += torch.unbind(images.reshape(num_video_frame, -1, images.shape[-1])) |
|
|
media_embeds["video"] = deque(media_embeds_video) |
|
|
|
|
|
|
|
|
while media_embeds.get("dummy"): |
|
|
dummy_embed = media_embeds["dummy"].popleft() |
|
|
text_embeds += torch.sum(dummy_embed) * 0 |
|
|
|
|
|
|
|
|
video_sound_embeds_idx = 0 |
|
|
sep_embed = self.encoders["video"].embed_tokens("\n") |
|
|
text_embeds = text_embeds.to(self.dtype) |
|
|
sep_embed = sep_embed.to(text_embeds.dtype) |
|
|
|
|
|
if video_info is not None and self.config.load_audio_in_video and self.config.interleaved_vis_aud_in_video: |
|
|
assert self.encoders["video"].end_tokens is None, "end_tokens must be None for interleaved vis-aud in video" |
|
|
new_video_embeds = deque() |
|
|
video_embeds_idx = 0 |
|
|
for k in range(len(video_info)): |
|
|
if video_info[k] is None: |
|
|
continue |
|
|
for i in range(len(video_info[k])): |
|
|
has_audio = video_info[k][i]["has_audio"] |
|
|
if not has_audio: |
|
|
new_video_embeds.append(media_embeds["video"][video_embeds_idx]) |
|
|
video_embeds_idx += 1 |
|
|
continue |
|
|
|
|
|
|
|
|
if video_sound_embeds_idx >= len(media_embeds["sound"]): |
|
|
raise ValueError(f"Sound embeddings index {video_sound_embeds_idx} out of bounds for video_info[{k}][{i}]") |
|
|
|
|
|
segment_aud_indices_list = video_info[k][i]["segment_aud_indices_list"] |
|
|
segment_vis_indices_list = video_info[k][i]["segment_vis_indices_list"] |
|
|
|
|
|
vis_fea_len_per_frame = media_embeds["video"][video_embeds_idx].shape[0] / video_info[k][i]["expected_frame_count"] |
|
|
aud_fea_len_per_stft_frame = media_embeds["sound"][video_sound_embeds_idx].shape[0] / audio_info[k][i]["new_audio_n_stft_frames"] |
|
|
vis_end = 0 |
|
|
aud_end = 0 |
|
|
_new_video_embed = [] |
|
|
for j in range(len(segment_vis_indices_list)): |
|
|
_vis_aud_fea = [] |
|
|
if len(segment_vis_indices_list[j]) > 0: |
|
|
_new_frames = [int(np.ceil((_frame+1) * vis_fea_len_per_frame)) for _frame in segment_vis_indices_list[j]] |
|
|
_vis_fea_end = _new_frames[-1] |
|
|
|
|
|
_vis_fea_end = min(_vis_fea_end, media_embeds["video"][video_embeds_idx].shape[0]) |
|
|
                            if j == len(segment_vis_indices_list) - 1 and k == len(video_info) - 1 and i == len(video_info[k]) - 1 and _vis_fea_end != media_embeds["video"][video_embeds_idx].shape[0]:
|
|
print(f"Warning: The number of last interleaved video features does not match the video feature length. Expected: {media_embeds['video'][video_embeds_idx].shape[0]}, Got: {_vis_fea_end}") |
|
|
_vis_fea_end = media_embeds["video"][video_embeds_idx].shape[0] |
|
|
_vis_fea = media_embeds["video"][video_embeds_idx][vis_end:_vis_fea_end] |
|
|
vis_end = _vis_fea_end |
|
|
_vis_aud_fea.append(_vis_fea) |
|
|
_vis_aud_fea.append(sep_embed) |
|
|
if len(segment_aud_indices_list[j]) > 0: |
|
|
_new_audio_indices = [int(np.ceil(_fea * aud_fea_len_per_stft_frame)) for _fea in segment_aud_indices_list[j]] |
|
|
_aud_fea_end = _new_audio_indices[-1] |
|
|
|
|
|
_aud_fea_end = min(_aud_fea_end, media_embeds["sound"][video_sound_embeds_idx].shape[0]) |
|
|
_aud_fea = media_embeds["sound"][video_sound_embeds_idx][aud_end:_aud_fea_end] |
|
|
_vis_aud_fea.append(_aud_fea) |
|
|
aud_end = _aud_fea_end |
|
|
_vis_aud_fea.append(sep_embed) |
|
|
_new_video_embed.append(torch.cat(_vis_aud_fea, dim=0)) |
|
|
video_sound_embeds_idx += 1 |
|
|
new_video_embeds.append(torch.cat(_new_video_embed, dim=0)) |
|
|
video_embeds_idx += 1 |
|
|
|
|
|
assert len(new_video_embeds) == len(media_embeds["video"]), "The number of new video embeddings does not match the number of original video embeddings." |
|
|
media_embeds["video"] = new_video_embeds |
|
|
|
|
|
batch_size = labels.shape[0] |
|
|
text_embeds = [text_embeds[k][attention_mask[k]] for k in range(batch_size)] |
|
|
labels = [labels[k][attention_mask[k]] for k in range(batch_size)] |
|
|
|
|
|
media_tokens = {} |
|
|
for name, token_id in self.tokenizer.media_token_ids.items(): |
|
|
media_tokens[token_id] = name |
|
|
|
|
|
|
|
|
inputs_m, labels_m = [], [] |
|
|
sound_embeds_idx = 0 |
|
|
for k in range(batch_size): |
|
|
inputs_mk, labels_mk = [], [] |
|
|
pos = 0 |
|
|
while pos < len(labels[k]): |
|
|
if input_ids[k][pos].item() in media_tokens: |
|
|
name = media_tokens[input_ids[k][pos].item()] if PROCESS_GROUP_MANAGER is None else "video" |
|
|
if input_ids[k][pos].item() == self.tokenizer.media_token_ids["sound"]: |
|
|
if self.config.interleaved_vis_aud_in_video: |
|
|
if sound_embeds_idx < video_sound_embeds_idx: |
|
|
media_embeds[name].popleft() |
|
|
sound_embeds_idx += 1 |
|
|
pos += 1 |
|
|
continue |
|
|
sound_embeds_idx += 1 |
|
|
|
|
|
end = pos + 1 |
|
|
input = media_embeds[name].popleft() |
|
|
label = torch.full([input.shape[0]], IGNORE_INDEX, device=labels[k].device, dtype=labels[k].dtype) |
|
|
else: |
|
|
end = pos |
|
|
while end < len(labels[k]) and input_ids[k][end].item() not in media_tokens: |
|
|
end += 1 |
|
|
input = text_embeds[k][pos:end] |
|
|
label = labels[k][pos:end] |
|
|
|
|
|
inputs_mk.append(input) |
|
|
labels_mk.append(label) |
|
|
pos = end |
|
|
inputs_m.append(torch.cat(inputs_mk, dim=0)) |
|
|
labels_m.append(torch.cat(labels_mk, dim=0)) |
|
|
inputs, labels = inputs_m, labels_m |
|
|
|
|
|
inputs[0] += sep_embed.mean() * 0 |
|
|
|
|
|
|
|
|
for name in media_embeds: |
|
|
if media_embeds[name]: |
|
|
raise ValueError(f"Not all {name} embeddings are consumed! Still {len(media_embeds[name])} left.") |
|
|
|
|
|
|
|
|
inputs, labels = self.__truncate_sequence(inputs, labels) |
|
|
|
|
|
|
|
|
return self.__batchify_sequence(inputs, labels) |
|
|
|
|
|
def __embed_media_tokens( |
|
|
self, |
|
|
media: Dict[str, List[torch.Tensor]], |
|
|
media_config: Dict[str, Dict[str, Any]], |
|
|
mm_info, |
|
|
) -> Dict[str, List[torch.Tensor]]: |
|
|
embeds = defaultdict(deque) |
|
|
|
|
|
if self.config.unified_audio_encoder: |
|
|
assert len(media["speech"]) == 0 |
|
|
|
|
|
for name in media: |
|
|
_encoder = self.encoders[name] |
|
|
if name in ["speech", "sound"] and self.config.unified_audio_encoder: |
|
|
_encoder = self.encoders["sound"] |
|
|
|
|
|
if self.training: |
|
|
|
|
|
if name in ["speech", "sound"]: |
|
|
|
|
|
info = [] |
|
|
if type(media.get(name, {})) is dict: |
|
|
for _dict in media.get(name, {}): |
|
|
info.append({k: {"shape": v.shape, "dtype": v.dtype} for k, v in _dict.items()}) |
|
|
elif type(media.get(name, {})) is list: |
|
|
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])] |
|
|
else: |
|
|
raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}") |
|
|
|
|
|
infos_list = vila_all_gather(info) |
|
|
infos = list(chain(*infos_list)) |
|
|
|
|
|
|
|
|
if not infos: |
|
|
continue |
|
|
|
|
|
|
|
|
max_batch_size = max(len(_info) for _info in infos_list) |
|
|
missing_batch_size = max_batch_size - len(info) |
|
|
|
|
|
_media = media.get(name, []) |
|
|
|
|
|
_medias = list(chain(vila_all_gather(_media))) |
|
|
if missing_batch_size > 0: |
|
|
for i in range(missing_batch_size): |
|
|
|
|
|
if type(media.get(name, {})) is dict: |
|
|
_dummy = {k: v.clone().to(device=self.device) for k, v in _medias[0].items()} |
|
|
elif type(media.get(name, {})) is list: |
|
|
if type(_medias[0]) is torch.Tensor: |
|
|
_dummy = _medias[0].clone().to(device=self.device) |
|
|
elif type(_medias[0]) is np.ndarray: |
|
|
_dummy = _medias[0].copy() |
|
|
else: |
|
|
raise ValueError(f"Unsupported media type: {type(_medias[0])}") |
|
|
else: |
|
|
raise ValueError(f"Unsupported media type: {type(media.get(name, {}))}") |
|
|
_media.append(_dummy) |
|
|
mm_info["audio_info"].append(["dummy"]) |
|
|
|
|
|
|
|
|
cur_batch_max_audio_samples = max(len(_audio) for _audio in _medias) |
|
|
cur_batch_max_audio_samples = int(np.ceil(cur_batch_max_audio_samples / (self.config.audio_sampling_rate * 30)) * (self.config.audio_sampling_rate * 30)) |
|
|
cur_batch_max_audio_samples = min(cur_batch_max_audio_samples, self.config.audio_chunk_length * self.config.audio_sampling_rate) |
|
|
cur_batch_max_audio_duration = cur_batch_max_audio_samples // self.config.audio_sampling_rate |
|
|
|
|
|
|
|
|
whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained( |
|
|
self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length |
|
|
) |
|
|
|
|
|
|
|
|
new_media = [] |
|
|
|
|
|
aud_idx = 0 |
|
|
for _batch_idx in range(len(mm_info["audio_info"])): |
|
|
_audio_info = mm_info["audio_info"][_batch_idx] |
|
|
if _audio_info is not None: |
|
|
for _mm_idx in range(len(_audio_info)): |
|
|
_audio = _media[aud_idx] |
|
|
if type(_audio) is torch.Tensor: |
|
|
device = _audio.device |
|
|
dtype = _audio.dtype |
|
|
_audio = _audio.cpu().float() |
|
|
else: |
|
|
|
|
|
device = self.device |
|
|
dtype = self.dtype |
|
|
_audio = whisper.pad_or_trim(_audio, length=cur_batch_max_audio_samples) |
|
|
aud_idx += 1 |
|
|
stft_features = whisper_feature_extractor( |
|
|
_audio, |
|
|
sampling_rate=self.config.audio_sampling_rate, |
|
|
return_attention_mask=True, |
|
|
padding="max_length", |
|
|
return_tensors="pt", |
|
|
).to(device, dtype) |
|
|
new_media.append(stft_features) |
|
|
if _audio_info[_mm_idx] != "dummy": |
|
|
_audio_info[_mm_idx]["new_audio_chunk_length"] = cur_batch_max_audio_duration |
|
|
_audio_info[_mm_idx]["new_audio_n_samples"] = cur_batch_max_audio_samples |
|
|
_audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + cur_batch_max_audio_duration |
|
|
_audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1] |
|
|
|
|
|
assert aud_idx == len(_media), "The number of audio info does not match the number of audio samples." |
|
|
_media = new_media |
|
|
|
|
|
_fea = _encoder(_media, media_config[name], mm_info) |
|
|
|
|
|
|
|
|
_dummy_fea = _fea[len(info) :] |
|
|
embeds["dummy"].extend(_dummy_fea) |
|
|
|
|
|
|
|
|
_real_fea = _fea[: len(info)] |
|
|
if len(info) > 0: |
|
|
embeds[name] = deque(_real_fea) |
|
|
|
|
|
else: |
|
|
|
|
|
info = [{"shape": tensor.shape, "dtype": tensor.dtype} for tensor in media.get(name, [])] |
|
|
infos = list(chain(vila_all_gather(info))) |
|
|
|
|
|
|
|
|
if not infos: |
|
|
continue |
|
|
|
|
|
|
|
|
if media.get(name) is None or len(media[name]) == 0: |
|
|
dummy = torch.zeros(infos[0]["shape"], dtype=infos[0]["dtype"], device=self.device) |
|
|
embeds["dummy"].extend(self.encoders[name]([dummy], media_config[name])) |
|
|
continue |
|
|
embeds[name] = deque(self.encoders[name](media[name], media_config[name])) |
|
|
|
|
|
else: |
|
|
if name == "sound": |
|
|
all_audio_chunk_lengths = [] |
|
|
for _sample_idx in range(len(media[name])): |
|
|
for _mm_idx in range(len(mm_info["audio_info"][_sample_idx])): |
|
|
_new_audio_chunk_length = mm_info["audio_info"][_sample_idx][_mm_idx]["new_audio_chunk_length"] |
|
|
all_audio_chunk_lengths.append(_new_audio_chunk_length) |
|
|
cur_batch_max_audio_duration = max(all_audio_chunk_lengths) |
|
|
cur_batch_max_audio_samples = cur_batch_max_audio_duration * self.config.audio_sampling_rate |
|
|
|
|
|
|
|
|
|
|
|
whisper_feature_extractor = WhisperFeatureExtractor.from_pretrained( |
|
|
self.config._name_or_path, chunk_length=cur_batch_max_audio_duration, sampling_rate=self.config.audio_sampling_rate, hop_length=self.config.audio_hop_length |
|
|
) |
|
|
|
|
|
new_media = [] |
|
|
_idx = 0 |
|
|
assert len(all_audio_chunk_lengths) == len(media[name]), "The number of audio chunk lengths does not match the number of audio samples." |
|
|
|
|
|
_media = media.get(name, []) |
|
|
aud_idx = 0 |
|
|
for _batch_idx in range(len(mm_info["audio_info"])): |
|
|
_audio_info = mm_info["audio_info"][_batch_idx] |
|
|
if _audio_info is not None: |
|
|
for _mm_idx in range(len(_audio_info)): |
|
|
_audio = _media[aud_idx] |
|
|
if type(_audio) is torch.Tensor: |
|
|
device = _audio.device |
|
|
dtype = _audio.dtype |
|
|
_audio = _audio.cpu().float() |
|
|
else: |
|
|
device = self.device |
|
|
dtype = self.dtype |
|
|
_audio = whisper.pad_or_trim(_audio, length=cur_batch_max_audio_samples) |
|
|
aud_idx += 1 |
|
|
stft_features = whisper_feature_extractor( |
|
|
_audio, |
|
|
sampling_rate=self.config.audio_sampling_rate, |
|
|
return_attention_mask=True, |
|
|
padding="max_length", |
|
|
return_tensors="pt", |
|
|
).to(device, dtype) |
|
|
|
|
|
new_media.append(stft_features) |
|
|
if _audio_info[_mm_idx] != "dummy": |
|
|
_audio_info[_mm_idx]["new_audio_chunk_length"] = cur_batch_max_audio_duration |
|
|
_audio_info[_mm_idx]["new_audio_n_samples"] = cur_batch_max_audio_samples |
|
|
_audio_info[_mm_idx]["audio_end_sample_sec"] = _audio_info[_mm_idx]["audio_start_sec"] + cur_batch_max_audio_duration |
|
|
_audio_info[_mm_idx]["new_audio_n_stft_frames"] = stft_features["input_features"].shape[-1] |
|
|
media[name] = new_media |
|
|
|
|
|
if len(media[name]) > 0: |
|
|
embeds[name] = deque(_encoder(media[name], media_config[name], mm_info)) |
|
|
return embeds |
|
|
|
|
|
def __truncate_sequence( |
|
|
self, inputs: List[torch.Tensor], labels: List[torch.Tensor] |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
if self.training and any(len(input) > self.tokenizer.model_max_length for input in inputs): |
|
|
warnings.warn(f"Truncating sequences to `model_max_length` ({self.tokenizer.model_max_length}).") |
|
|
inputs = [input[: self.tokenizer.model_max_length] for input in inputs] |
|
|
labels = [label[: self.tokenizer.model_max_length] for label in labels] |
|
|
return inputs, labels |
|
|
|
|
|
def __batchify_sequence( |
|
|
self, inputs: List[torch.Tensor], labels: List[torch.Tensor] |
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
|
batch_size = len(inputs) |
|
|
device = inputs[0].device |
|
|
hidden_size = inputs[0].shape[1] |
|
|
max_length = max(inputs[k].shape[0] for k in range(batch_size)) |
|
|
attention_mask = torch.ones((batch_size, max_length), dtype=torch.bool, device=device) |
|
|
|
|
|
inputs_p, labels_p = [], [] |
|
|
for k in range(batch_size): |
|
|
size_pk = max_length - inputs[k].shape[0] |
|
|
inputs_pk = torch.zeros((size_pk, hidden_size), dtype=inputs[k].dtype, device=device) |
|
|
labels_pk = torch.full((size_pk,), IGNORE_INDEX, dtype=labels[k].dtype, device=device) |
|
|
if self.tokenizer.padding_side == "right": |
|
|
attention_mask[k, inputs[k].shape[0] :] = False |
|
|
inputs_pk = torch.cat([inputs[k], inputs_pk], dim=0) |
|
|
labels_pk = torch.cat([labels[k], labels_pk], dim=0) |
|
|
else: |
|
|
labels[k] = labels[k].to(device) |
|
|
attention_mask[k, : -inputs[k].shape[0]] = False |
|
|
inputs_pk = torch.cat([inputs_pk, inputs[k]], dim=0) |
|
|
labels_pk = torch.cat([labels_pk, labels[k]], dim=0) |
|
|
inputs_p.append(inputs_pk) |
|
|
labels_p.append(labels_pk) |
|
|
|
|
|
inputs = torch.stack(inputs_p, dim=0) |
|
|
labels = torch.stack(labels_p, dim=0) |
|
|
return inputs, labels, attention_mask |
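
    # Illustrative sketch (comment only): with left padding (tokenizer.padding_side == "left", the default
    # set in __init__), a batch of embedded sequences of lengths [3, 5] is padded to length 5 as
    #
    #   inputs:  [[PAD, PAD, e0, e1, e2],        attention_mask: [[0, 0, 1, 1, 1],
    #             [e0,  e1,  e2, e3, e4]]                         [1, 1, 1, 1, 1]]
    #
    # and the padded label positions are filled with IGNORE_INDEX.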
|
|
|
|
|
def repack_multimodal_data(self, inputs_embeds, attention_mask, position_ids, labels): |
|
|
|
|
|
PROCESS_GROUP_MANAGER = None |
|
|
|
|
|
|
|
|
if PROCESS_GROUP_MANAGER is not None: |
|
|
sp_degree = PROCESS_GROUP_MANAGER.sp_degree |
|
|
sp_rank = PROCESS_GROUP_MANAGER.sp_rank |
|
|
sp_group = PROCESS_GROUP_MANAGER.sp_pg |
|
|
ring_degree = PROCESS_GROUP_MANAGER.ring_degree |
|
|
ring_rank = PROCESS_GROUP_MANAGER.ring_rank |
|
|
ring_type = PROCESS_GROUP_MANAGER.ring_type |
|
|
ulysses_degree = PROCESS_GROUP_MANAGER.ulysses_degree |
|
|
ulysses_rank = PROCESS_GROUP_MANAGER.ulysses_rank |
|
|
|
|
|
bs, shard_seqlen = position_ids.shape |
|
|
sp_seq_len = [torch.zeros(1, dtype=torch.int64, device=position_ids.device) for _ in range(sp_degree)] |
|
|
dist.all_gather(sp_seq_len, torch.tensor(shard_seqlen, device=position_ids.device), group=sp_group) |
|
|
sp_seq_len_cat = torch.cat(sp_seq_len, dim=0) |
|
|
|
|
|
if sp_rank == 0: |
|
|
original_start_id = 0 |
|
|
else: |
|
|
original_start_id = torch.sum(sp_seq_len_cat[:sp_rank]).item() |
|
|
original_end_id = torch.sum(sp_seq_len_cat[: sp_rank + 1]).item() |
|
|
|
|
|
|
|
|
all_inputs_embeds = torch.zeros( |
|
|
bs, |
|
|
torch.sum(sp_seq_len_cat), |
|
|
inputs_embeds.shape[-1], |
|
|
dtype=inputs_embeds.dtype, |
|
|
device=inputs_embeds.device, |
|
|
).contiguous() |
|
|
all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds |
|
|
dist.barrier(group=sp_group) |
|
|
dist.all_reduce(all_inputs_embeds, group=sp_group) |
|
|
dist.barrier(group=sp_group) |
|
|
|
|
|
attention_mask_list = [ |
|
|
torch.zeros((bs, sp_seq_len[i]), dtype=attention_mask.dtype, device=attention_mask.device) |
|
|
for i in range(sp_degree) |
|
|
] |
|
|
position_ids_list = [ |
|
|
torch.zeros((bs, sp_seq_len[i]), dtype=position_ids.dtype, device=position_ids.device) |
|
|
for i in range(sp_degree) |
|
|
] |
|
|
labels_list = [ |
|
|
torch.zeros((bs, sp_seq_len[i]), dtype=labels.dtype, device=labels.device) for i in range(sp_degree) |
|
|
] |
|
|
|
|
|
dist.all_gather(attention_mask_list, attention_mask, group=sp_group) |
|
|
dist.all_gather(position_ids_list, position_ids, group=sp_group) |
|
|
dist.all_gather(labels_list, labels, group=sp_group) |
|
|
|
|
|
effective_seqlen_list = [attention_mask_list[i].sum(dim=-1) for i in range(sp_degree)] |
|
|
effective_seqlen = torch.stack(effective_seqlen_list, dim=-1) |
|
|
effective_seqlen_batch_list = torch.unbind(effective_seqlen, dim=0) |
|
|
|
|
|
global_attention_mask_list = [] |
|
|
global_position_ids_list = [] |
|
|
global_labels_list = [] |
|
|
global_inputs_embeds_list = [] |
|
|
for i in range(bs): |
|
|
global_attention_mask_batch_list = [] |
|
|
global_position_ids_batch_list = [] |
|
|
global_labels_batch_list = [] |
|
|
global_inputs_embeds_batch_list = [] |
|
|
for j in range(sp_degree): |
|
|
eff_len = effective_seqlen_batch_list[i][j] |
|
|
prev_len = torch.sum(sp_seq_len_cat[:j]).item() if j > 0 else 0 |
|
|
|
|
|
global_attention_mask_batch_list.append(attention_mask_list[j][i, :eff_len]) |
|
|
global_position_ids_batch_list.append(position_ids_list[j][i, :eff_len]) |
|
|
global_labels_batch_list.append(labels_list[j][i, :eff_len]) |
|
|
global_inputs_embeds_batch_list.append(all_inputs_embeds[i, prev_len : prev_len + eff_len, :]) |
|
|
global_attention_mask_list.append(torch.cat(global_attention_mask_batch_list, dim=0)) |
|
|
global_position_ids_list.append(torch.cat(global_position_ids_batch_list, dim=0)) |
|
|
global_labels_list.append(torch.cat(global_labels_batch_list, dim=0)) |
|
|
global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0)) |
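            # Re-pad the concatenated per-sample sequences to a common length within the batch.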
|
|
|
|
|
global_attention_mask = torch.nn.utils.rnn.pad_sequence( |
|
|
global_attention_mask_list, batch_first=True, padding_value=False |
|
|
) |
|
|
global_position_ids = torch.nn.utils.rnn.pad_sequence( |
|
|
global_position_ids_list, batch_first=True, padding_value=-1 |
|
|
) |
|
|
global_labels = torch.nn.utils.rnn.pad_sequence( |
|
|
global_labels_list, batch_first=True, padding_value=IGNORE_INDEX |
|
|
) |
|
|
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence( |
|
|
global_inputs_embeds_list, batch_first=True, padding_value=0 |
|
|
) |
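            # Re-shard the gathered sequences across the sequence-parallel ranks: ring attention
            # gets per-rank chunks of equal effective length, otherwise a simple contiguous split is used.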
|
|
|
|
|
|
|
|
if ring_degree > 1: |
|
|
total_effective_seqlen = torch.sum(effective_seqlen, dim=1) |
|
|
new_seqlen_per_rank = total_effective_seqlen // sp_degree |
|
|
assert torch.all( |
|
|
total_effective_seqlen % sp_degree == 0 |
|
|
), "total_effective_seqlen must be divisible by sp_degree" |
|
|
|
|
|
max_new_seqlen = torch.max(new_seqlen_per_rank).item() |
|
|
|
|
|
new_attention_mask = torch.zeros( |
|
|
(bs, max_new_seqlen), dtype=global_attention_mask.dtype, device=global_attention_mask.device |
|
|
) |
|
|
new_position_ids = torch.zeros( |
|
|
(bs, max_new_seqlen), dtype=global_position_ids.dtype, device=global_position_ids.device |
|
|
) |
|
|
new_labels = torch.full( |
|
|
(bs, max_new_seqlen), IGNORE_INDEX, dtype=global_labels.dtype, device=global_labels.device |
|
|
) |
|
|
new_inputs_embeds = torch.zeros( |
|
|
(bs, max_new_seqlen, global_inputs_embeds.shape[-1]), |
|
|
dtype=global_inputs_embeds.dtype, |
|
|
device=global_inputs_embeds.device, |
|
|
) |
|
|
|
|
|
if ring_type == "ring_varlen": |
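                    # Contiguous split: rank `sp_rank` takes the sp_rank-th equal-sized chunk of each sample.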
|
|
for i in range(bs): |
|
|
start_idx = new_seqlen_per_rank[i] * sp_rank |
|
|
end_idx = start_idx + new_seqlen_per_rank[i] |
|
|
new_attention_mask[i, : new_seqlen_per_rank[i]] = global_attention_mask[i, start_idx:end_idx] |
|
|
new_position_ids[i, : new_seqlen_per_rank[i]] = global_position_ids[i, start_idx:end_idx] |
|
|
new_labels[i, : new_seqlen_per_rank[i]] = global_labels[i, start_idx:end_idx] |
|
|
new_inputs_embeds[i, : new_seqlen_per_rank[i], :] = global_inputs_embeds[ |
|
|
i, start_idx:end_idx, : |
|
|
] |
|
|
elif ring_type == "zigzag_ring_varlen": |
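                    # Zigzag split: each rank takes one chunk from the front half and the mirrored chunk
                    # from the back half (2 * sp_degree chunks in total), which balances causal-attention work.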
|
|
chunk_size = total_effective_seqlen // (2 * sp_degree) |
|
|
for i in range(bs): |
|
|
|
|
|
if sp_degree == ring_degree: |
|
|
forward_rank_idx = sp_rank |
|
|
backward_rank_idx = 2 * sp_degree - sp_rank - 1 |
|
|
else: |
|
|
ulysses_offset = ulysses_rank * ring_degree * 2 |
|
|
forward_rank_idx = ring_rank + ulysses_offset |
|
|
backward_rank_idx = sp_degree - ring_rank - 1 + ulysses_offset |
|
|
|
|
|
|
|
|
start_idx_fwd = forward_rank_idx * chunk_size[i] |
|
|
end_idx_fwd = start_idx_fwd + chunk_size[i] |
|
|
|
|
|
start_idx_bwd = backward_rank_idx * chunk_size[i] |
|
|
end_idx_bwd = start_idx_bwd + chunk_size[i] |
|
|
|
|
|
|
|
|
new_attention_mask[i, : chunk_size[i]] = global_attention_mask[i, start_idx_fwd:end_idx_fwd] |
|
|
new_attention_mask[i, chunk_size[i] : 2 * chunk_size[i]] = global_attention_mask[ |
|
|
i, start_idx_bwd:end_idx_bwd |
|
|
] |
|
|
|
|
|
new_position_ids[i, : chunk_size[i]] = global_position_ids[i, start_idx_fwd:end_idx_fwd] |
|
|
new_position_ids[i, chunk_size[i] : 2 * chunk_size[i]] = global_position_ids[ |
|
|
i, start_idx_bwd:end_idx_bwd |
|
|
] |
|
|
|
|
|
new_labels[i, : chunk_size[i]] = global_labels[i, start_idx_fwd:end_idx_fwd] |
|
|
new_labels[i, chunk_size[i] : 2 * chunk_size[i]] = global_labels[i, start_idx_bwd:end_idx_bwd] |
|
|
|
|
|
new_inputs_embeds[i, : chunk_size[i], :] = global_inputs_embeds[i, start_idx_fwd:end_idx_fwd, :] |
|
|
new_inputs_embeds[i, chunk_size[i] : 2 * chunk_size[i], :] = global_inputs_embeds[ |
|
|
i, start_idx_bwd:end_idx_bwd, : |
|
|
] |
|
|
else: |
|
|
raise ValueError(f"Invalid ring_type: {ring_type}") |
|
|
else: |
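                # Ulysses-only case: evenly split the padded global sequence; the last rank absorbs any remainder.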
|
|
global_seq_len = global_attention_mask.shape[-1] |
|
|
seq_len_sharded = global_seq_len // sp_degree |
|
|
start_idx_reshard = seq_len_sharded * sp_rank |
|
|
end_idx_reshard = start_idx_reshard + seq_len_sharded if sp_rank < sp_degree - 1 else global_seq_len |
|
|
|
|
|
new_attention_mask = torch.narrow( |
|
|
global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
|
) |
|
|
new_position_ids = torch.narrow( |
|
|
global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
|
) |
|
|
new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard) |
|
|
new_inputs_embeds = torch.narrow( |
|
|
global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard |
|
|
) |
|
|
|
|
|
return new_inputs_embeds, new_attention_mask, new_position_ids, new_labels |
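        # Packing path (no sequence parallelism): strip padding from every sample and
        # concatenate the whole batch into a single packed sequence of shape (1, total_tokens).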
|
|
|
|
|
device = inputs_embeds.device |
|
|
batch_size = inputs_embeds.shape[0] |
|
|
seqlens = [attention_mask[k].sum().item() for k in range(batch_size)] |
|
|
|
|
|
|
|
|
inputs_embeds_p = [inputs_embeds[k][attention_mask[k]] for k in range(batch_size)] |
|
|
attention_mask_p = [torch.ones(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)] |
|
|
position_ids_p = [torch.arange(seqlens[k], dtype=torch.int, device=device) for k in range(batch_size)] |
|
|
labels_p = [labels[k][attention_mask[k]] for k in range(batch_size)] |
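        # Append a dummy single-token segment (attention mask 0, position 0, label IGNORE_INDEX)
        # at the end of the packed batch.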
|
|
|
|
|
|
|
|
inputs_embeds_p.append(torch.zeros(1, inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=device)) |
|
|
attention_mask_p.append(torch.tensor([0], dtype=torch.int, device=device)) |
|
|
position_ids_p.append(torch.tensor([0], dtype=torch.int, device=device)) |
|
|
labels_p.append(torch.tensor([IGNORE_INDEX], dtype=torch.int, device=device)) |
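        # Mask the first token of every sample so the loss cannot leak across packed-sample boundaries.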
|
|
|
|
|
|
|
|
for label in labels_p: |
|
|
label[0] = IGNORE_INDEX |
|
|
|
|
|
|
|
|
inputs_embeds_p = torch.cat(inputs_embeds_p, dim=0).unsqueeze(0) |
|
|
attention_mask_p = torch.cat(attention_mask_p, dim=0).unsqueeze(0) |
|
|
position_ids_p = torch.cat(position_ids_p, dim=0).unsqueeze(0) |
|
|
labels_p = torch.cat(labels_p, dim=0).unsqueeze(0) |
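        # Optionally round the packed length up to a multiple of `pad_to_multiple_of`, padding with masked tokens.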
|
|
|
|
|
        if hasattr(self, "pad_to_multiple_of"):
|
|
assert len(labels_p.shape) == 2 |
|
|
batch_size, max_length, cur_length = labels_p.shape[0], labels_p.shape[1], labels_p.shape[1] |
|
|
hidden_size = inputs_embeds_p.shape[-1] |
|
|
|
|
|
if max_length % self.pad_to_multiple_of != 0: |
|
|
max_length = ((max_length // self.pad_to_multiple_of) + 1) * self.pad_to_multiple_of |
|
|
difference = max_length - cur_length |
|
|
|
|
|
inputs_embeds_p = torch.cat( |
|
|
( |
|
|
inputs_embeds_p, |
|
|
torch.full((batch_size, difference, hidden_size), self.llm.pad_token_id).to(inputs_embeds_p), |
|
|
), |
|
|
dim=1, |
|
|
) |
|
|
labels_p = torch.cat((labels_p, torch.full((batch_size, difference), IGNORE_INDEX).to(labels_p)), dim=1) |
|
|
attention_mask_p = torch.cat( |
|
|
( |
|
|
attention_mask_p, |
|
|
torch.zeros((batch_size, difference), dtype=torch.bool).to(attention_mask_p), |
|
|
), |
|
|
dim=1, |
|
|
) |
|
|
position_ids_p = torch.cat( |
|
|
(position_ids_p, torch.full((batch_size, difference), -1).to(position_ids_p)), dim=1 |
|
|
) |
|
|
|
|
|
return inputs_embeds_p, attention_mask_p, position_ids_p, labels_p |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: torch.LongTensor = None, |
|
|
media: Optional[Dict[str, List[torch.Tensor]]] = None, |
|
|
images: Optional[torch.FloatTensor] = None, |
|
|
media_config: Optional[List] = None, |
|
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
|
labels: Optional[torch.LongTensor] = None, |
|
|
packing: bool = True, |
|
|
force_packing: bool = False, |
|
|
seqlens_in_batch: Optional[torch.LongTensor] = None, |
|
|
dpo_forward: bool = False, |
|
|
**kwargs, |
|
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
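        # Re-apply the frozen-module patch so modules frozen for training stay in eval mode
        # even though the trainer calls model.train() every step.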
|
|
self.freezed_module_patch() |
|
|
|
|
|
if images is not None: |
|
|
if media is not None: |
|
|
raise ValueError("Both 'media' and 'images' are provided. Please provide only one.") |
|
|
print("The 'images' argument is deprecated. Please use 'media' instead.") |
|
|
media = {"image": images} |
|
|
|
|
|
if media_config is None: |
|
|
media_config = defaultdict(dict) |
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds, labels, attention_mask = self._embed(input_ids, media, media_config, labels, attention_mask) |
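        # During training (or when forced), pack the batch: record the per-sample lengths for the
        # varlen attention kernel, then flatten the padded batch into a single packed sequence.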
|
|
|
|
|
if force_packing or (packing and self.training and not dpo_forward): |
|
|
if seqlens_in_batch is None: |
|
|
seqlens_in_batch = torch.sum(attention_mask, dim=1) |
|
|
set_seqlens_in_batch(seqlens_in_batch) |
|
|
|
|
|
(inputs_embeds, attention_mask, position_ids, labels) = self.repack_multimodal_data( |
|
|
inputs_embeds, attention_mask, position_ids, labels |
|
|
) |
|
|
|
|
|
outputs = self.llm( |
|
|
inputs_embeds=inputs_embeds, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
past_key_values=past_key_values, |
|
|
labels=labels, |
|
|
**kwargs, |
|
|
) |
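        # Optionally replace the LM loss with a soft cross-entropy over the configured time tokens.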
|
|
|
|
|
if self.training and getattr(self.config, "time_token_ids", []): |
|
|
outputs.loss = soft_cross_entropy( |
|
|
outputs.logits, |
|
|
labels, |
|
|
soft_tokens=self.config.time_token_ids, |
|
|
std=self.config.soft_ce_std, |
|
|
) |
|
|
|
|
|
if dpo_forward: |
|
|
return outputs.logits, labels |
|
|
|
|
|
return outputs |
|
|
|
|
|
@torch.inference_mode() |
|
|
def generate( |
|
|
self, |
|
|
input_ids: Optional[torch.FloatTensor] = None, |
|
|
media: Optional[Dict[str, List[torch.Tensor]]] = None, |
|
|
media_config: Dict[str, Dict[str, Any]] = None, |
|
|
attention_mask: Optional[torch.LongTensor] = None, |
|
|
return_output_ids_only: bool = True, |
|
|
**generation_kwargs, |
|
|
) -> torch.LongTensor: |
|
|
""" |
|
|
input_tokens: <image> describe the image |
|
|
media: [Tensor(1, 3, 384, 384), ] |
|
|
-----------> |
|
|
input_tokens: 36000 001 002 003 004 |
|
|
input_emds: <media emd> 001 002 003 004 |
|
|
""" |
|
|
inputs_embeds, _, attention_mask = self._embed(input_ids, media, media_config, None, attention_mask) |
|
|
|
|
|
output_ids = self.llm.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generation_kwargs) |
|
|
|
|
|
if return_output_ids_only: |
|
|
return_value = output_ids |
|
|
else: |
|
|
|
|
|
generation_config = generation_kwargs.get("generation_config", None) |
|
|
if generation_config is not None: |
|
|
num_generations = generation_config.num_return_sequences |
|
|
repeat_input_ids = input_ids.repeat_interleave(num_generations, dim=0) |
|
|
return_value = torch.cat([repeat_input_ids, output_ids], dim=-1) |
|
|
else: |
|
|
return_value = torch.cat([input_ids, output_ids], dim=-1) |
|
|
|
|
|
return return_value |
|
|
|
|
|
@torch.inference_mode() |
|
|
def generate_content( |
|
|
self, |
|
|
prompt: Union[str, List], |
|
|
generation_config: Optional[GenerationConfig] = None, |
|
|
response_format=None, |
|
|
) -> str: |
|
|
conversation = [{"from": "human", "value": prompt}] |
|
|
|
|
|
|
|
|
xgr_logits_processor = None |
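        # NOTE: constrained decoding is not wired up here, so `response_format` is currently ignored.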
|
|
|
|
|
|
|
|
|
|
|
media = extract_media(conversation, self.config) |
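        # Preprocess each extracted modality into the tensor format its encoder expects.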
|
|
|
|
|
|
|
|
media_config = defaultdict(dict) |
|
|
for name in media: |
|
|
if name == "image": |
|
|
if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]: |
|
|
self.config.image_processor = self.vision_tower.image_processor |
|
|
if self.config.image_aspect_ratio == "dynamic": |
|
|
images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half() |
|
|
conversation[0]["value"] = conversation[0]["value"].replace( |
|
|
DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0] |
|
|
) |
|
|
else: |
|
|
if type(self.config.s2_scales) is str: |
|
|
self.config.s2_scales = list(map(int, self.config.s2_scales.split(","))) |
|
|
images, block_sizes = process_image( |
|
|
media["image"][0], self.config, None, enable_dynamic_s2=True |
|
|
) |
|
|
images = images.half() |
|
|
media_config[name]["block_sizes"] = [block_sizes] |
|
|
else: |
|
|
images = process_images(media["image"], self.vision_tower.image_processor, self.config).half() |
|
|
media[name] = [image for image in images] |
|
|
elif name == "video": |
|
|
if self.config.image_aspect_ratio == "dynamic" and self.config.video_max_tiles > 1: |
|
|
media[name] = [ |
|
|
process_images( |
|
|
images, |
|
|
self.vision_tower.image_processor, |
|
|
self.config, |
|
|
enable_dynamic_res=True, |
|
|
max_tiles=self.config.video_max_tiles, |
|
|
).half() |
|
|
for images in media[name] |
|
|
] |
|
|
elif self.config.image_aspect_ratio == "dynamic_s2" and self.config.video_max_tiles > 1: |
|
|
self.config.image_processor = self.vision_tower.image_processor |
|
|
if type(self.config.s2_scales) is str: |
|
|
self.config.s2_scales = list(map(int, self.config.s2_scales.split(","))) |
|
|
media[name] = [ |
|
|
torch.cat( |
|
|
[ |
|
|
process_image( |
|
|
image, |
|
|
self.config, |
|
|
None, |
|
|
enable_dynamic_s2=True, |
|
|
max_tiles=self.config.video_max_tiles, |
|
|
)[0].half() |
|
|
for image in images |
|
|
] |
|
|
) |
|
|
for images in media[name] |
|
|
] |
|
|
else: |
|
|
media[name] = [ |
|
|
process_images(images, self.vision_tower.image_processor, self.config) |
|
|
for images in media[name] |
|
|
] |
|
|
elif name == "speech": |
|
|
speeches = media["speech"] |
|
|
media[name] = [speech for speech in speeches] |
|
|
elif name == "sound": |
|
|
|
|
|
sounds = media["sound"] |
|
|
|
|
|
for sound in sounds: |
|
|
if type(sound) is dict: |
|
|
for k, v in sound.items(): |
|
|
sound[k] = v.half() |
|
|
media[name] = [sound for sound in sounds] |
|
|
elif name == "video_info": |
|
|
media[name] = [media["video_info"]] |
|
|
elif name == "audio_info": |
|
|
media[name] = [media["audio_info"]] |
|
|
else: |
|
|
raise ValueError(f"Unsupported media type: {name}") |
|
|
|
|
|
|
|
|
input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).unsqueeze(0).cuda() |
|
|
|
|
|
|
|
|
generation_config = generation_config or self.default_generation_config |
|
|
|
|
|
|
|
|
try: |
|
|
output_ids = self.generate( |
|
|
input_ids=input_ids, |
|
|
media=media, |
|
|
media_config=media_config, |
|
|
generation_config=generation_config, |
|
|
logits_processor=xgr_logits_processor, |
|
|
) |
|
|
except ValueError: |
|
|
if not generation_config.do_sample: |
|
|
raise |
|
|
logging.warning("Generation failed with sampling, retrying with greedy decoding.") |
|
|
generation_config.do_sample = False |
|
|
output_ids = self.generate( |
|
|
input_ids=input_ids, |
|
|
media=media, |
|
|
media_config=media_config, |
|
|
generation_config=generation_config, |
|
|
logits_processor=xgr_logits_processor, |
|
|
) |
|
|
|
|
|
|
|
|
response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() |
|
|
return response |
|
|
|
|
|
@property |
|
|
def default_generation_config(self) -> GenerationConfig: |
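        """Generation config with token ids and max length defaulted from the tokenizer when unset."""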
|
|
generation_config = copy.deepcopy(self.generation_config or GenerationConfig()) |
|
|
if self.tokenizer.eos_token_id is None: |
|
|
raise ValueError("Tokenizer must have an EOS token") |
|
|
if generation_config.max_length == GenerationConfig().max_length: |
|
|
generation_config.max_length = self.tokenizer.model_max_length |
|
|
if generation_config.pad_token_id is None: |
|
|
generation_config.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id |
|
|
if generation_config.bos_token_id is None: |
|
|
generation_config.bos_token_id = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id |
|
|
if generation_config.eos_token_id is None: |
|
|
generation_config.eos_token_id = self.tokenizer.eos_token_id |
|
|
return generation_config |
|
|
|