#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2026 The OpenBMB Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import math
import os
import tempfile
import threading
import time
import types
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from threading import Thread
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
from torch import nn
from torch.nn.init import trunc_normal_
from torch.nn.utils.parametrizations import weight_norm
from tqdm import tqdm
# Optional FlagOS path: when the USE_FLAGOS environment variable equals "1", the RMSNorm
# classes exposed by the transformers Llama and Qwen3 modeling modules are replaced with
# a module that calls the `rmsnorm` op from `flag_gems.experimental_ops`.
# NOTE(review): presumably this has to run before any Llama/Qwen3 model is built so the
# patched class is the one that gets instantiated - confirm.
if os.getenv("USE_FLAGOS") == "1":
    import importlib

    flag_gems = importlib.import_module("flag_gems") # noqa: F401
    flag_gems_experimental = importlib.import_module("flag_gems.experimental_ops")
    gems_rmsnorm = flag_gems_experimental.rmsnorm

    class GemsRMSNorm(nn.Module):
        """RMSNorm layer whose forward pass is delegated to ``gems_rmsnorm``.

        Args:
            hidden_size: Size of the learnable ``weight`` vector (initialised to ones).
            eps: Epsilon forwarded to ``gems_rmsnorm``; stored as ``variance_epsilon``.
        """

        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def forward(self, hidden_states):
            """Return ``gems_rmsnorm(hidden_states, weight, eps)``."""
            return gems_rmsnorm(hidden_states, self.weight, self.variance_epsilon)

        def extra_repr(self):
            """Show the weight shape and epsilon in the module repr."""
            return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

    from transformers.models.llama import modeling_llama
    from transformers.models.qwen3 import modeling_qwen3

    # Swap the stock implementations for the flag_gems-backed one.
    modeling_qwen3.Qwen3RMSNorm = GemsRMSNorm
    modeling_llama.LlamaRMSNorm = GemsRMSNorm
from transformers import LlamaConfig
from transformers import LlamaModel
from transformers import PreTrainedModel
from transformers import Qwen3ForCausalLM
from transformers import Qwen3PreTrainedModel
from transformers import TextIteratorStreamer
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import EncoderDecoderCache
from transformers.cache_utils import StaticCache
from transformers.generation.logits_process import TopKLogitsWarper
from transformers.generation.logits_process import TopPLogitsWarper
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import ModelOutput
from transformers.models.whisper.configuration_whisper import WhisperConfig
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from .configuration_minicpmo import MiniCPMOConfig
from .configuration_minicpmo import MiniCPMTTSConfig
from .modeling_navit_siglip import SiglipVisionTransformer
from .processing_minicpmo import MiniCPMOProcessor
from .utils import as_dynamic_cache
from .utils import ChunkPrefillChunkGenerate
from .utils import drop_tokens_from_cache
from .utils import DuplexWindowConfig
from .utils import get_kv_cache_length
from .utils import normalize_content
from .utils import realign_rotary_suffix
from .utils import SpeculativeSnapshot
from .utils import streaming_token_decoder
from .utils import StreamingWindowConfig
from .utils import torch_clone_recursive
from .utils import TTSSamplingParams
from .utils import TTSStreamingGenerator
logger = logging.getLogger(__name__)
class MiniCPMOPreTrainedModel(Qwen3PreTrainedModel):
    """``Qwen3PreTrainedModel`` subclass that uses ``MiniCPMOConfig`` as its config class."""

    config_class = MiniCPMOConfig
class MiniCPMO(MiniCPMOPreTrainedModel):
def __init__(self, config):
    """Build the LLM backbone and the optional vision / audio / TTS sub-modules.

    Which sub-modules are created is controlled by ``config.init_vision``,
    ``config.init_audio`` and ``config.init_tts``. Streaming/session state is
    initialised through ``reset_session`` and a set of streaming constants.

    Args:
        config: Model configuration; also passed to ``Qwen3ForCausalLM``.
    """
    super().__init__(config)
    self.llm = Qwen3ForCausalLM(config)
    self.embed_dim = self.llm.config.hidden_size
    # Bind the module-level `prepare_inputs_for_generation` function onto this llm instance.
    self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm) # patch llm
    # init vision module
    if self.config.init_vision:
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
    # init audio module
    if self.config.init_audio:
        self.apm = self.init_audio_module()
        audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
        self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
        self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
        # Index into the encoder's `hidden_states` tuple used as audio features (-1 = last entry).
        self.audio_encoder_layer = -1
    # init tts module
    if self.config.init_tts:
        self.tts = self.init_tts_module()
    # Token strings converted to ids and used as EOS during decoding.
    self.terminators = ["<|im_end|>", "<|endoftext|>"]
    self.think_str = ""
    if self.llm.__class__.__name__ == "Qwen3ForCausalLM":
        self.think_str = "\\n\\n\\n\\n"
    # for streaming
    self.reset_session(reset_token2wav_cache=True)
    # streaming audio processing constants
    self.SAMPLE_RATE = 16000
    self.CHUNK_MS = 1000 # regular chunk length (ms)
    self.FIRST_CHUNK_MS = 1035 # first chunk length (ms)
    self.CNN_REDUNDANCY_MS = 0 # CNN redundancy (ms)
    # for sliding window
    self.streaming_window_config = StreamingWindowConfig()
    self.streaming_require_system_prompt = True
    self.streaming_window_enabled = True
    self.force_rope_reindex = False # RoPE reindex testing switch
def init_streaming_processor(self):
    """Prepare the processor for streaming audio input and reset the chunk counter.

    Ensures ``self.processor`` exists, configures its streaming mode when the
    processor exposes ``set_streaming_mode``, resets its streaming state and
    sets ``self.audio_chunk_idx`` back to 0.
    """
    self.prepare_processor(processor=None, tokenizer=None)

    if hasattr(self.processor, "set_streaming_mode"):
        streaming_options = {
            "mode": "exact",
            "chunk_ms": self.CHUNK_MS,
            "first_chunk_ms": self.FIRST_CHUNK_MS,
            "cnn_redundancy_ms": self.CNN_REDUNDANCY_MS,
            "enable_sliding_window": True,
            "slide_trigger_seconds": 30.0,
            "slide_stride_seconds": 10.0,
        }
        self.processor.set_streaming_mode(**streaming_options)

    self.processor.reset_streaming()
    self.audio_chunk_idx = 0
def reset_session(self, reset_token2wav_cache=True):
    """Reset all per-session state kept on the model.

    Args:
        reset_token2wav_cache: When truthy, ``self.token2wav_cache`` is also
            cleared; otherwise it is left untouched.
    """
    # Caches and identifiers that start out unset.
    for attr_name in (
        "llm_past_key_values",
        "audio_past_key_values",
        "tts_last_turn_tokens",
        "session_id",
    ):
        setattr(self, attr_name, None)

    # Turn-level flags.
    self.llm_generated = False  # last turn generated by llm or not
    self.llm_generate_completed = False
    self.new_user_msg = True

    if reset_token2wav_cache:
        self.token2wav_cache = None

    # Sliding-window bookkeeping.
    self.streaming_text_preserve = 0
    self.streaming_position_offset = 0
    self._rope_inv_freq_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {}

    # Round / chunk history.
    self._next_round_id = 0
    self._pending_round_id = None
    self._omni_chunk_history: List[Dict[str, Union[str, int]]] = []
    self._round_history: List[Dict[str, Union[int, str, torch.Tensor, Optional[int]]]] = []
def init_vision_module(self):
    """Create the ``SiglipVisionTransformer`` used as the vision tower.

    The vision config's attention implementation is set to
    ``"flash_attention_2"`` only when the top-level config uses it, otherwise
    ``"eager"``. When ``config.drop_vision_last_layer`` is set, the last
    encoder layer is dropped. ``embed_dim`` and ``patch_size`` are copied from
    the embeddings module onto the returned model.
    """
    wants_flash = self.config._attn_implementation == "flash_attention_2"
    self.config.vision_config._attn_implementation = "flash_attention_2" if wants_flash else "eager"

    vision_tower = SiglipVisionTransformer(self.config.vision_config)
    if self.config.drop_vision_last_layer:
        vision_tower.encoder.layers = vision_tower.encoder.layers[:-1]

    # Expose the embedding geometry directly on the tower.
    vision_tower.embed_dim = vision_tower.embeddings.embed_dim
    vision_tower.patch_size = vision_tower.embeddings.patch_size
    return vision_tower
def init_resampler(self, embed_dim, vision_dim):
    """Create the ``Resampler`` that maps vision features into the LLM embedding space.

    Args:
        embed_dim: Output embedding size; the head count is ``embed_dim // 128``.
        vision_dim: Passed to the resampler as ``kv_dim``.
    """
    resampler_kwargs = {
        "num_queries": self.config.query_num,
        "embed_dim": embed_dim,
        "num_heads": embed_dim // 128,
        "kv_dim": vision_dim,
        "adaptive": True,
    }
    return Resampler(**resampler_kwargs)
def init_audio_module(self):
    """Create the ``MiniCPMWhisperEncoder`` audio tower.

    The audio config uses ``"eager"`` attention when the top-level config does,
    and ``"sdpa"`` in every other case.
    """
    # using flash_attention_2 will cause: RuntimeError: cu_seqlens_q must have shape (batch_size + 1)
    # so anything that is not "eager" falls back to "sdpa".
    is_eager = self.config._attn_implementation == "eager"
    self.config.audio_config._attn_implementation = "eager" if is_eager else "sdpa"
    return MiniCPMWhisperEncoder(self.config.audio_config)
def init_tts_module(self):
    """Create the ``MiniCPMTTS`` module (without an audio tokenizer).

    ``tts_config.attn_implementation`` is set to ``"flash_attention_2"`` when
    the top-level config uses it, otherwise to ``"eager"``.
    """
    wants_flash = self.config._attn_implementation == "flash_attention_2"
    self.config.tts_config.attn_implementation = "flash_attention_2" if wants_flash else "eager"
    return MiniCPMTTS(config=self.config.tts_config, audio_tokenizer=None)
def _ensure_asset_dir(self, asset_subpath: str, model_dir: Optional[str] = None) -> str:
"""Ensure asset directory exists, downloading from HF if needed."""
model_dir = model_dir or os.path.join(self.config._name_or_path, asset_subpath)
if not os.path.exists(model_dir):
from huggingface_hub import snapshot_download
repo_dir = snapshot_download(
repo_id="openbmb/MiniCPM-o-4_5",
allow_patterns=[f"{asset_subpath}/**"],
)
model_dir = os.path.join(repo_dir, asset_subpath)
assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}"
return model_dir
def init_tts(self, model_dir=None, enable_float16=False, n_timesteps=10, **kwargs):
    """Load the ``Token2wav`` audio tokenizer and attach it to ``self.tts``.

    Args:
        model_dir: Directory with the token2wav assets; resolved through
            ``_ensure_asset_dir("assets/token2wav", model_dir)``.
        enable_float16: Forwarded to ``Token2wav`` as ``float16``.
        n_timesteps: Forwarded to ``Token2wav``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        The created ``Token2wav`` instance (also stored as ``self.tts.audio_tokenizer``).

    Raises:
        ImportError: If ``stepaudio2`` cannot be imported. The original import
            failure is attached as the explicit cause.
    """
    if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio":
        logger.warning("audio tokenizer type is set to s3tokenizer_step_audio")
        self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio"

    try:
        from stepaudio2 import Token2wav
    except ImportError as exc:
        # Chain explicitly so the underlying import failure stays visible as the cause.
        raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]") from exc

    model_dir = self._ensure_asset_dir("assets/token2wav", model_dir)
    self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps)
    return self.tts.audio_tokenizer
def get_input_embeddings(self):
    """Return the input embedding module reported by the wrapped LLM."""
    return self.llm.get_input_embeddings()
def set_input_embeddings(self, value):
    """Replace the input embedding module of the wrapped LLM.

    Delegates to ``self.llm.set_input_embeddings`` so that the module
    returned by ``get_input_embeddings`` is the one that gets replaced.
    Assigning ``self.llm.embed_tokens`` directly only created a new attribute
    on the LLM wrapper and left the embedding in use untouched.

    Args:
        value: The new embedding module.
    """
    self.llm.set_input_embeddings(value)
def get_output_embeddings(self):
    """Return the wrapped LLM's ``lm_head``."""
    return self.llm.lm_head
def set_output_embeddings(self, new_embeddings):
    """Replace the wrapped LLM's ``lm_head`` with ``new_embeddings``."""
    self.llm.lm_head = new_embeddings
def set_decoder(self, decoder):
    """Replace the wrapped LLM (``self.llm``) with ``decoder``."""
    self.llm = decoder
def get_decoder(self):
    """Return the wrapped LLM (``self.llm``)."""
    return self.llm
@staticmethod
def get_sys_prompt(ref_audio=None, mode="default", language="en", ref_audio_max_ms=None):
    """Build the system message for the requested interaction mode.

    Args:
        ref_audio: Reference voice, either a waveform as ``np.ndarray`` or a
            path to an audio file (loaded with librosa at 16 kHz mono).
            ``None`` means no reference voice.
        mode: One of ``"omni"``, ``"audio_assistant"``, ``"audio_roleplay"``,
            ``"voice_cloning"``; any other value yields the default prompt.
        language: ``"zh"`` selects the Chinese prompts, anything else English.
        ref_audio_max_ms: Optional cap (milliseconds) on how much of a
            reference audio *file* is loaded.

    Returns:
        dict: ``{"role": "system", "content": [...]}`` where content holds the
        prompt strings and, when given, the reference waveform.

    Raises:
        AssertionError: If ``ref_audio`` is given but does not end up as an
            ``np.ndarray`` (this includes a path that does not exist).
        ValueError: In ``"voice_cloning"`` mode without a reference audio.
    """
    if ref_audio is not None:
        if isinstance(ref_audio, str):
            import librosa

            if os.path.isfile(ref_audio):
                duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None
                ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration)
            else:
                logger.error(f"Could not find {ref_audio}")
                ref_audio = None
        assert isinstance(ref_audio, np.ndarray), "ref_audio error"

    if mode == "omni":
        if language == "zh":
            sys_prompt = ""
            vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。"
            vc_prompt_suffix = (
                "请用这种声音风格来为用户提供帮助。 请认真、高质量地回复用户的问题。 请用高自然度的方式和用户聊天。"
            )
        else:
            sys_prompt = ""
            vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "As an assistant, you will speak using this voice style."

        if ref_audio is not None:
            sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
        else:
            sys_msgs = {"role": "system", "content": [sys_prompt]}
        return sys_msgs
    elif mode == "audio_assistant":
        if language == "zh":
            vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。"
            vc_prompt_suffix = "你的任务是用这种声音模式来当一个助手。请认真、高质量地回复用户的问题。请用高自然度的方式和用户聊天。你是由面壁智能开发的人工智能助手:面壁小钢炮。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "Please assist users while maintaining this voice style. Please answer the user's questions seriously and in a high quality. Please chat with the user in a highly human-like and oral style. You are a helpful assistant developed by ModelBest: MiniCPM-Omni."

        if ref_audio is not None:
            sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
        else:
            logger.warning(
                "Warning: ref_audio is None, speech generation will be performed based on the default voice."
            )
            sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]}
        return sys_msgs
    elif mode == "audio_roleplay":
        if language == "zh":
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
            vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."
            vc_prompt_suffix = "Try to role-play the character based on the audio prompt above."

        if ref_audio is not None:
            sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]}
        else:
            sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]}
        return sys_msgs
    elif mode == "voice_cloning":
        if language == "zh":
            vc_prompt_prefix = "模仿输入音频中的声音特征。"
        else:
            vc_prompt_prefix = "Clone the voice in the provided audio prompt."

        if ref_audio is not None:
            sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio]}
        else:
            raise ValueError("ref_audio can't be None in voice_cloning mode.")
        return sys_msgs
    else:
        sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text."
        sys_msgs = {"role": "system", "content": [sys_prompt]}
        return sys_msgs
@staticmethod
def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
    num_lookhead: int = 0,
) -> torch.Tensor:
    """Create a chunk-wise attention mask of shape ``(size, size)`` for a streaming encoder.

    Row ``i`` is ``True`` for the columns from the start of its visible
    history up to the end of its own chunk plus ``num_lookhead`` positions
    (clipped to ``size``).

    Args:
        size (int): size of the (square) mask
        chunk_size (int): size of a chunk
        num_left_chunks (int): number of left chunks visible to each row
            <0: all previous chunks are visible
            >=0: only ``num_left_chunks`` previous chunks are visible
        device (torch.device): device on which the mask is created
        num_lookhead (int): extra positions visible past the end of the
            row's own chunk; expected to be non-negative

    Returns:
        torch.Tensor: boolean mask, ``True`` where attention is allowed
    """
    # Built with broadcasting instead of a Python loop over rows.
    positions = torch.arange(size, device=device)
    chunk_index = positions // chunk_size

    # Exclusive end column for every row.
    ending = torch.clamp((chunk_index + 1) * chunk_size + num_lookhead, max=size)
    # Inclusive start column for every row.
    if num_left_chunks < 0:
        start = torch.zeros_like(ending)
    else:
        start = torch.clamp((chunk_index - num_left_chunks) * chunk_size, min=0)

    columns = positions.unsqueeze(0)
    return (columns >= start.unsqueeze(1)) & (columns < ending.unsqueeze(1))
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""Computes the output length of the convolutional layers and the output length of the audio encoder"""
input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
input_lengths_after_pooling = (
input_lengths_after_cnn - self.config.audio_pool_step
) // self.config.audio_pool_step + 1
input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
return input_lengths_after_cnn, input_lengths_after_pooling
def get_vision_embedding(self, data):
    """Return per-sample vision features for the images in ``data``.

    If ``data`` already holds ``"vision_hidden_states"`` they are returned
    unchanged. Otherwise all images in ``data["pixel_values"]`` are padded
    into one batch, encoded by ``self.vpm`` (in chunks of
    ``config.vision_batch_size`` when there are more images than that),
    resampled, and split back per sample.

    Returns:
        list: one entry per sample - a slice of the resampled features, or an
        empty list for a sample without images. When no sample has images,
        every entry is a dummy feature in training mode and ``[]`` otherwise.
    """
    if "vision_hidden_states" in data:
        return data["vision_hidden_states"]

    embed_weight = self.llm.model.embed_tokens.weight
    dtype = embed_weight.dtype
    device = embed_weight.device
    tgt_sizes = data["tgt_sizes"]
    pixel_values_list = data["pixel_values"]

    # Flatten every image of every sample into a (patches, channels) sequence.
    flat_images = [
        image.flatten(end_dim=1).permute(1, 0) for sample_images in pixel_values_list for image in sample_images
    ]

    if not flat_images:  # no image
        if self.training:
            # Run a dummy image through the tower so its parameters stay in the graph.
            dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
            patch_size = self.config.patch_size
            dummy_sizes = torch.Tensor([[(224 // patch_size), math.ceil(224 / patch_size)]]).type(torch.int32)
            dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, dummy_sizes)
        else:
            dummy_feature = []
        return [dummy_feature for _ in range(len(pixel_values_list))]

    # exist image
    tgt_sizes = torch.vstack([size for size in tgt_sizes if isinstance(size, torch.Tensor)]).type(torch.int32)
    patch_counts = tgt_sizes[:, 0] * tgt_sizes[:, 1]
    max_patches = torch.max(patch_counts)

    padded = torch.nn.utils.rnn.pad_sequence(flat_images, batch_first=True, padding_value=0.0)
    num_images, padded_len, _ = padded.shape
    padded = padded.permute(0, 2, 1).reshape(num_images, 3, -1, padded_len)

    # Mark the real (non-padding) patches of each image.
    patch_attn_mask = torch.zeros((num_images, 1, max_patches), dtype=torch.bool, device=device)
    for row in range(num_images):
        patch_attn_mask[row, 0, : patch_counts[row]] = True

    padded = padded.type(dtype)
    step = self.config.vision_batch_size
    if num_images > step:
        chunk_states = [
            self.vpm(
                padded[lo : lo + step],
                patch_attention_mask=patch_attn_mask[lo : lo + step],
                tgt_sizes=tgt_sizes[lo : lo + step],
            ).last_hidden_state
            for lo in range(0, num_images, step)
        ]
        vision_embedding = torch.cat(chunk_states, dim=0)
    else:
        vision_embedding = self.vpm(
            padded,
            patch_attention_mask=patch_attn_mask,
            tgt_sizes=tgt_sizes,
        ).last_hidden_state
    vision_embedding = self.resampler(vision_embedding, tgt_sizes)

    # Hand each sample back its own consecutive slice of image features.
    vision_hidden_states = []
    offset = 0
    for sample_images in pixel_values_list:
        count = len(sample_images)
        if count > 0:
            vision_hidden_states.append(vision_embedding[offset : offset + count])
            offset += count
        else:
            vision_hidden_states.append([])
    return vision_hidden_states
def get_vllm_embedding(self, data):
    """Embed ``data["input_ids"]`` and write the vision features into the image slots.

    Returns:
        Tuple ``(vllm_embedding, vision_hidden_states)``: the token embeddings
        (modified in place at the positions given by ``data["image_bound"]``)
        and the vision features cast to the embedding dtype.
    """
    vision_hidden_states = self.get_vision_embedding(data)

    vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])
    if hasattr(self.llm.config, "scale_emb"):
        vllm_embedding = vllm_embedding * self.llm.config.scale_emb

    target_dtype = vllm_embedding.dtype
    vision_hidden_states = [
        feat.type(target_dtype) if isinstance(feat, torch.Tensor) else feat for feat in vision_hidden_states
    ]

    for sample_idx in range(len(data["input_ids"])):
        sample_features = vision_hidden_states[sample_idx]
        if len(sample_features) == 0:
            continue

        sample_emb = vllm_embedding[sample_idx]
        bounds = data["image_bound"][sample_idx]
        if len(bounds) > 0:
            positions = torch.stack([torch.arange(b[0], b[1], dtype=torch.long) for b in bounds]).to(
                vllm_embedding.device
            )
            hidden_size = sample_emb.shape[-1]
            sample_emb.scatter_(
                0,
                positions.view(-1, 1).repeat(1, hidden_size),
                sample_features.view(-1, sample_features.shape[-1]),
            )
        elif self.training:
            # Zero-valued term that keeps the vision features connected to the graph.
            sample_emb += sample_features[0].mean() * 0

    return vllm_embedding, vision_hidden_states
def get_audio_embedding_streaming(
    self,
    data,
    use_extra_context=False,
    prefix_extra_frames=1,
    suffix_extra_frames=1,
    cnn_min_length=None,
):
    """Extract audio embeddings in a streaming manner using cached key-value pairs.

    Incoming audio features are processed incrementally and
    ``self.audio_past_key_values`` is stored/updated so that later chunks can
    reuse the encoder cache. Only ``batch_size == 1`` is supported.

    Args:
        data (dict):
            - **"audio_features"**: input mel-spectrograms, unpacked as
              ``(batch_size, _, frames)``. A missing, ``None`` or empty
              value means "no audio".
            - **"audio_feature_lens"** (List[List[int]]): lengths of each audio
              segment for each item in the batch.
        use_extra_context (bool): If True, assumes input contains extra frames for CNN context.
        prefix_extra_frames (int): Number of prefix extra frames.
        suffix_extra_frames (int): Number of suffix extra frames.
        cnn_min_length (int): Minimum length for CNN input padding (forwarded to the encoder).

    Returns:
        List[List[torch.Tensor]]: audio embeddings grouped like
        ``audio_feature_lens``; ``[]`` when there is no audio.
    """
    wavforms = data.get("audio_features")  # (bs, 80, frames) or [], multi audios need filled in advance
    if wavforms is None or len(wavforms) == 0:
        return []
    audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]

    # exist audio
    audio_feature_lens = torch.hstack(audio_feature_lens_raw)
    batch_size, _, max_mel_seq_len = wavforms.shape
    assert batch_size == 1
    max_seq_len = (max_mel_seq_len - 1) // 2 + 1

    # whisper's past_key_values management (core): drop the cache once the
    # cached length plus the new chunk would reach the positional-embedding table size.
    if self.audio_past_key_values is not None:
        cache_length = self.audio_past_key_values[0][0].shape[2]
        apm_max_len = self.apm.embed_positions.weight.shape[0]
        if cache_length + max_seq_len >= apm_max_len:
            logger.warning(
                f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset."
            )
            self.audio_past_key_values = None

    # build attention mask (bidirectional attention, same as offline mode)
    current_seq_len = max_seq_len
    if use_extra_context:
        # calculate actual sequence length after removing redundancy
        # conv2's stride=2, so the mapping from mel frames to output frames is ceil(x/2)
        prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0
        suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0
        current_seq_len = current_seq_len - prefix_to_remove - suffix_to_remove

    # calculate history length (if there is KV cache)
    if self.audio_past_key_values is not None:
        past_len = self.audio_past_key_values[0][0].shape[2]
    else:
        past_len = 0
    total_seq_len = past_len + current_seq_len

    # All-zero additive mask: nothing is masked out (full attention).
    audio_attention_mask = torch.zeros(
        (batch_size, 1, current_seq_len, total_seq_len),
        dtype=self.apm.conv1.weight.dtype,
        device=wavforms.device,
    )

    # Step 1: APM processing
    audio_outputs = self.apm(
        wavforms,
        past_key_values=self.audio_past_key_values,
        use_cache=True,
        output_hidden_states=True,
        attention_mask=audio_attention_mask,
        use_extra_context=use_extra_context,
        prefix_extra_frames=prefix_extra_frames,
        suffix_extra_frames=suffix_extra_frames,
        cnn_min_length=cnn_min_length,
    )
    if hasattr(self, "audio_encoder_layer"):
        audio_states = audio_outputs.hidden_states[self.audio_encoder_layer]
    else:
        audio_states = audio_outputs.last_hidden_state
    self.audio_past_key_values = audio_outputs.past_key_values

    # Step 2: Projection
    audio_embeds = self.audio_projection_layer(audio_states)

    # Step 3: Pooling (AvgPool1d works on the last axis, hence the transposes)
    audio_embeds = audio_embeds.transpose(1, 2)
    audio_embeds = self.audio_avg_pooler(audio_embeds)
    audio_embeds = audio_embeds.transpose(1, 2)

    _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
    num_audio_tokens = feature_lens_after_pooling

    # Regroup the flat batch of embeddings to mirror audio_feature_lens_raw.
    final_audio_embeds = []
    idx = 0
    for lens_per_sample in audio_feature_lens_raw:
        target_audio_embeds = []
        for _ in range(len(lens_per_sample)):
            target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
            idx += 1
        final_audio_embeds.append(target_audio_embeds)
    return final_audio_embeds
def get_audio_embedding(self, data, chunk_length=-1, dummy=True):
    """Extract (offline) audio embeddings for all audio segments in ``data``.

    Args:
        data (dict):
            - **"audio_features"**: input mel-spectrograms, unpacked as
              ``(batch_size, _, frames)``. A missing, ``None`` or empty
              value means "no audio".
            - **"audio_feature_lens"**: nested list of per-segment lengths.
        chunk_length: When > 0, attention is additionally restricted by a
            chunk mask with ``int(chunk_length * 50)`` frames per chunk;
            otherwise only padding is masked.
        dummy: In training mode without audio, run a zero input through the
            audio tower and return its embedding instead of ``[]``.

    Returns:
        List[List[torch.Tensor]] grouped like ``audio_feature_lens`` when
        audio exists; ``[audio_embeds]`` for the training dummy; else ``[]``.
    """
    wavforms = data.get("audio_features")  # (bs, 80, frames) or [], multi audios need filled in advance
    if wavforms is None:
        # generate() forwards audio_features=None when the caller passes no audio.
        wavforms = []
    audio_feature_lens_raw = data.get("audio_feature_lens", [])  # list, [[x1, x2], [y1], [z1]]

    if len(wavforms) > 0:
        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
        batch_size, _, max_mel_seq_len = wavforms.shape
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feature_lens.dtype,
                device=audio_feature_lens.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        # NOTE(review): positions are counted in post-conv frames (max_seq_len) but compared
        # against the raw `audio_feature_lens` values - confirm this is intended.
        lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand  # 1 for padded values

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
        )

        if chunk_length > 0:
            chunk_num_frame = int(chunk_length * 50)
            chunk_mask = self.subsequent_chunk_mask(
                size=max_seq_len,
                chunk_size=chunk_num_frame,
                num_left_chunks=-1,
                device=audio_attention_mask_.device,
            )
            # Block a position if it is padding OR lies outside the allowed chunks.
            audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask))

        # Turn the boolean "blocked" mask into an additive -inf mask.
        audio_attention_mask[audio_attention_mask_] = float("-inf")
        audio_states = self.apm(
            wavforms, output_hidden_states=True, attention_mask=audio_attention_mask
        ).hidden_states[self.audio_encoder_layer]
        audio_embeds = self.audio_projection_layer(audio_states)

        # AvgPool1d works on the last axis, hence the transposes.
        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)

        _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens)
        num_audio_tokens = feature_lens_after_pooling

        # Regroup the flat batch of embeddings to mirror audio_feature_lens_raw.
        final_audio_embeds = []
        idx = 0
        for lens_per_sample in audio_feature_lens_raw:
            target_audio_embeds = []
            for _ in range(len(lens_per_sample)):
                target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :])
                idx += 1
            final_audio_embeds.append(target_audio_embeds)
        return final_audio_embeds

    if self.training and dummy:
        dtype = self.apm.embed_positions.weight.dtype
        device = self.apm.embed_positions.weight.device

        dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype)
        audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer]

        audio_embeds = self.audio_projection_layer(audio_states)

        audio_embeds = audio_embeds.transpose(1, 2)
        audio_embeds = self.audio_avg_pooler(audio_embeds)
        audio_embeds = audio_embeds.transpose(1, 2)
        return [audio_embeds]

    return []
def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
    """Write audio embeddings into ``input_embeddings`` at the audio placeholder positions.

    Args:
        data: Model input dict; ``"audio_features"`` (may be missing, ``None``
            or empty) and ``"audio_bounds"`` are read here.
        input_embeddings: Token embeddings, modified in place.
        chunk_length: whisper use full attention or chunk attention
        stream_input: use streaming audio embedding or not

    Returns:
        final embeddings with audio feature
    """
    if stream_input:
        audio_embeddings = self.get_audio_embedding_streaming(data)
    else:
        audio_embeddings = self.get_audio_embedding(data, chunk_length)

    bs = len(input_embeddings)
    audio_features = data.get("audio_features")
    has_audio = audio_features is not None and len(audio_features) > 0

    if has_audio:
        assert len(audio_embeddings) == len(input_embeddings)
        if len(audio_embeddings) > 0:
            audio_bounds = data["audio_bounds"]

            if self.config.stream_input:
                assert bs == 1, "audio stream_input mode only support batch size 1"
                for i in range(bs):
                    # Concatenate all segments, then consume them bound by bound.
                    audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
                        device=input_embeddings.device, dtype=input_embeddings.dtype
                    )
                    audio_start_pos = 0
                    for bound in audio_bounds[i]:
                        audio_len = bound[1] - bound[0]
                        input_embeddings[i, bound[0] : bound[1]] = audio_embs[
                            audio_start_pos : audio_start_pos + audio_len, :
                        ]
                        audio_start_pos += audio_len
            else:
                for i in range(bs):
                    for embs, bound in zip(audio_embeddings[i], audio_bounds[i]):
                        audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(
                            input_embeddings.device
                        )

                        if embs.shape[0] != len(audio_indices):
                            raise ValueError(
                                f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
                                f"to input indices of length {len(audio_indices)}"
                            )
                        input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype)
    elif self.training and len(audio_embeddings) > 0:
        # dummy audio embeddings: a zero-valued term that keeps the audio tower in the graph.
        # Adding it once is enough; the value is zero either way.
        input_embeddings += audio_embeddings[0].mean() * 0

    return input_embeddings
def forward(self, data, **kwargs):
    """Run the LLM on embeddings that merge text, vision and audio inputs.

    Args:
        data: Model input dict; ``"position_ids"`` is read here, the rest is
            consumed by ``get_vllm_embedding`` / ``get_omni_embedding``.
        **kwargs: Forwarded unchanged to ``self.llm``.

    Returns:
        Whatever ``self.llm`` returns.
    """
    inputs_embeds, _ = self.get_vllm_embedding(data)
    inputs_embeds = self.get_omni_embedding(
        data,
        input_embeddings=inputs_embeds,
        chunk_length=self.config.audio_chunk_length,
    )

    # The LLM is fed int64 position ids; already-int64 tensors pass through unchanged.
    position_ids = data["position_ids"].to(torch.int64)

    return self.llm(
        input_ids=None,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        **kwargs,
    )
def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
outputs = self.llm.generate(
inputs_embeds=inputs_embeds,
pad_token_id=0,
eos_token_id=terminators,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict_in_generate=True,
**kwargs,
)
return outputs
def _decode_stream(self, inputs_embeds, tokenizer, **kwargs):
    """Start ``self.llm.generate`` on a background thread and return its text streamer.

    Args:
        inputs_embeds: Embeddings passed to ``generate``.
        tokenizer: Used to resolve the terminator ids and by the streamer.
        **kwargs: Extra generation arguments; they override the defaults set here.

    Returns:
        The ``TextIteratorStreamer`` that receives the generated text.
    """
    eos_ids = [tokenizer.convert_tokens_to_ids(token) for token in self.terminators]
    streamer = TextIteratorStreamer(tokenizer=tokenizer)

    generation_kwargs = {
        "inputs_embeds": inputs_embeds,
        "pad_token_id": 0,
        "eos_token_id": eos_ids,
        "streamer": streamer,
        **kwargs,
    }

    worker = Thread(target=self.llm.generate, kwargs=generation_kwargs)
    worker.start()
    return streamer
def _decode_text(self, result_ids, tokenizer):
terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
result_text = []
for result in result_ids:
result = result[result != 0]
if result[0] == tokenizer.bos_id:
result = result[1:]
if result[-1] in terminators:
result = result[:-1]
result_text.append(tokenizer.decode(result))
return result_text
@torch.inference_mode()
def generate(
    self,
    input_ids=None,
    pixel_values=None,
    tgt_sizes=None,
    audio_features=None,
    audio_feature_lens=None,
    image_bound=None,
    audio_bounds=None,
    spk_bounds=None,
    attention_mask=None,
    tokenizer=None,
    vision_hidden_states=None,
    stream=False,
    **kwargs,
):
    """Build multimodal input embeddings and generate text with the LLM.

    Args:
        input_ids: Token ids (required).
        pixel_values, tgt_sizes: Image inputs; used only when
            ``vision_hidden_states`` is not supplied.
        audio_features, audio_feature_lens: Audio inputs.
        image_bound, audio_bounds, spk_bounds: Placeholder positions.
        attention_mask: Passed to generation in the non-streaming path.
        tokenizer: Used for terminator ids and for decoding.
        vision_hidden_states: Pre-computed vision features; when given,
            ``pixel_values`` may be omitted.
        stream: When True, generation runs through ``_decode_stream``.
        **kwargs: Forwarded to the decode helper.

    Returns:
        Tuple ``(result, outputs)``. Streaming: ``result`` is the value
        returned by ``_decode_stream`` and ``outputs`` is ``{}``. Otherwise:
        ``result`` is the decoded text list and ``outputs`` the raw output of
        ``_decode``.
    """
    assert input_ids is not None
    # pixel_values is only mandatory when the vision features still have to be computed.
    if vision_hidden_states is None or pixel_values is not None:
        assert len(input_ids) == len(pixel_values)

    model_inputs = {
        "input_ids": input_ids,
        "audio_features": audio_features,
        "audio_feature_lens": audio_feature_lens,
        "image_bound": image_bound,
        "audio_bounds": audio_bounds,
        "spk_bounds": spk_bounds,
    }

    if vision_hidden_states is None:
        model_inputs["pixel_values"] = pixel_values
        model_inputs["tgt_sizes"] = tgt_sizes
    else:
        model_inputs["vision_hidden_states"] = vision_hidden_states

    # The decorator already runs this method under inference mode.
    model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs)
    model_inputs["inputs_embeds"] = self.get_omni_embedding(
        model_inputs,
        input_embeddings=model_inputs["inputs_embeds"],
        chunk_length=self.config.audio_chunk_length,
    )

    if stream:
        result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs)
        outputs = {}  # if stream return TextIteratorStreamer and output is empty
    else:
        outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs)
        result = self._decode_text(outputs.sequences, tokenizer)

    return result, outputs
def _build_streaming_mask(self, tts_tokens_len):
tts_sequence_full_length = 1 + self.tts.streaming_text_reserved_len + 1
streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8)
streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1
streaming_attention_mask[-1] = 1
return streaming_attention_mask
def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
    """Drive the TTS module chunk by chunk for the text between the TTS markers.

    The text between ``<|tts_bos|>`` and ``<|tts_eos|>`` is tokenized, prefilled
    into the TTS model in chunks of ``self.tts.streaming_text_chunk_size``, and
    audio tokens are generated ``output_chunk_size`` at a time until the TTS
    model reports it is finished (or the new-token budget is exceeded).

    NOTE(review): this method relies on ``self._get_last_spk_embeds``,
    ``self.prepare_tts_text``, ``self.tts_processor``, ``self.force_no_stop`` and
    the module-level ``gen_logits``, none of which are defined in the part of
    the file visible here - confirm they exist. The method has no ``return``
    statement in the visible code.

    Args:
        inputs, outputs: Passed to ``self._get_last_spk_embeds``.
        text: Generated text containing the TTS markers.
        output_chunk_size: ``max_new_token`` for every ``self.tts.generate`` call.
        tts_max_new_tokens: Upper bound on ``outputs.new_ids.shape[1]`` in the
            final generation loop.
    """
    spk_embeds = self._get_last_spk_embeds(inputs, outputs)
    # Keep only the span after the last <|tts_bos|> and before the first <|tts_eos|>.
    text = text.split("<|tts_bos|>")[-1]
    gen_text = text.split("<|tts_eos|>")[0]
    tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
    tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
    tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long)
    streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)
    logits_warpers, logits_processors = gen_logits(
        num_code=626,
        top_p=self.tts.top_p,
        top_k=self.tts.top_k,
        repetition_penalty=self.tts.repetition_penalty,
    )
    # Same length as the mask built by _build_streaming_mask.
    condition_length = 1 + self.tts.streaming_text_reserved_len + 1
    dtype = self.tts.emb_text.weight.dtype
    emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device)
    # Zero-initialised per-layer (key, value) cache of length condition_length - 1.
    past_key_values = [
        (
            torch.zeros(
                1,
                self.tts.config.num_attention_heads,
                condition_length - 1,
                self.tts.config.hidden_size // self.tts.config.num_attention_heads,
                dtype=emb.dtype,
                device=self.tts.device,
            ),
            torch.zeros(
                1,
                self.tts.config.num_attention_heads,
                condition_length - 1,
                self.tts.config.hidden_size // self.tts.config.num_attention_heads,
                dtype=emb.dtype,
                device=self.tts.device,
            ),
        )
        for _ in range(self.tts.config.num_hidden_layers)
    ]
    audio_input_ids = torch.zeros(
        1,
        condition_length,
        self.tts.num_vq,
        dtype=torch.long,
        device=self.tts.device,
    )
    eos_lab = False
    # Interleave text prefill and audio generation, one text chunk at a time.
    for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)):
        if chunk_idx == 0:
            # The first chunk covers one extra leading position (index 0).
            begin = chunk_idx * self.tts.streaming_text_chunk_size + 0
            end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1
        else:
            begin = chunk_idx * self.tts.streaming_text_chunk_size + 1
            end = min(
                (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1,
                condition_length - 1,
            )
        if end - begin > 0:
            text_input_ids = tts_input_ids[:, begin:end]
            position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)
            if begin == 0:
                # Only the very first prefill receives the speaker embedding.
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    lm_spk_emb_last_hidden_states=spk_embeds,
                )
            else:
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                )
        outputs = self.tts.generate(
            input_ids=audio_input_ids,
            past_key_values=past_key_values,
            streaming_tts_text_mask=streaming_tts_text_mask,
            max_new_token=output_chunk_size,
            force_no_stop=self.force_no_stop,
            temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
            eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
            logits_warpers=logits_warpers,
            logits_processors=logits_processors,
        )
        audio_input_ids = outputs.audio_input_ids
        past_key_values = outputs.past_key_values
        if outputs.finished:
            eos_lab = True
            break
    # All text has been prefilled but the TTS model has not finished: keep generating.
    if not eos_lab:
        while True:
            outputs = self.tts.generate(
                input_ids=audio_input_ids,
                past_key_values=past_key_values,
                streaming_tts_text_mask=streaming_tts_text_mask,
                max_new_token=output_chunk_size,
                force_no_stop=self.force_no_stop,
                temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
                eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
                logits_warpers=logits_warpers,
                logits_processors=logits_processors,
            )
            audio_input_ids = outputs.audio_input_ids
            past_key_values = outputs.past_key_values
            if outputs.finished:
                break
            # Safety stop once more than tts_max_new_tokens new ids were produced.
            if outputs.new_ids.shape[1] > tts_max_new_tokens:
                break
@staticmethod
def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs):
num_beams = kwargs.get("num_beams", 3)
generation_config = {
"num_beams": num_beams,
"top_p": 0.8,
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
"repetition_penalty": 1.02,
}
if do_sample:
generation_config.update(
{
"top_p": 0.8,
"top_k": 100,
"temperature": 0.7,
"do_sample": True,
"repetition_penalty": 1.02,
}
)
elif num_beams > 1:
generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False})
else:
generation_config.update({"do_sample": False, "repetition_penalty": 1.02})
generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())
generation_config["min_new_tokens"] = min_new_tokens
generation_config["max_new_tokens"] = max_new_tokens
return generation_config
def prepare_processor(self, processor=None, tokenizer=None):
    """Make sure ``self.processor`` is set, optionally swapping in overrides.

    Args:
        processor: If given, replaces ``self.processor``.
        tokenizer: If given, replaces ``self.processor.tokenizer``.

    When no processor is available afterwards, one is loaded with
    ``MiniCPMOProcessor.from_pretrained`` from ``self.config._name_or_path``.
    """
    if processor is not None:
        self.processor = processor

    # Covers both "attribute never set" and "attribute set to None".
    if getattr(self, "processor", None) is None:
        self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)

    if tokenizer is None:
        return
    self.processor.tokenizer = tokenizer
@torch.inference_mode()
def chat(
    self,
    image=None,
    msgs=None,
    vision_hidden_states=None,
    max_new_tokens=4096,
    min_new_tokens=0,
    do_sample=True,
    max_inp_length=8192,
    max_slice_nums=None,
    use_image_id=None,
    enable_thinking=False,
    use_tts_template=False,
    generate_audio=False,
    output_audio_path=None,
    output_tts_inputs_embeds_path=None,
    omni_mode=False,
    teacher_forcing=False,
    return_prompt=False,
    tts_proj_layer=-1,
    tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
    merge_audio_from_same_content=True,
    stream=False,
    stream_input=False,
    tokenizer=None,
    processor=None,
    **kwargs,
):
    """Run one chat request and optionally write synthesized speech to disk.

    Args:
        image: Optional image. In the non-batched case it is prepended to the
            first message when that message's content is a plain string.
            Must be ``None`` for batched input.
        msgs: A conversation (list of ``{"role", "content"}`` dicts), a list
            of conversations (batched), or a JSON string per conversation.
        vision_hidden_states: Forwarded to ``self.generate``.
        max_new_tokens: Forwarded to ``self.generate`` (forced to 1 when
            ``teacher_forcing`` is set).
        min_new_tokens: Stored in the generation config.
        do_sample: Selects sampling vs. beam/greedy defaults; must be truthy
            when ``stream`` is set.
        max_inp_length: Passed to the processor as ``max_length``.
        max_slice_nums: Passed to the processor.
        use_image_id: Passed to the processor.
        enable_thinking: Passed to ``apply_chat_template``.
        use_tts_template: Passed to ``apply_chat_template``; switched on as
            soon as any ``np.ndarray`` (audio) content item is encountered.
        generate_audio: Together with ``use_tts_template`` and
            ``output_audio_path`` enables speech synthesis.
        output_audio_path: Destination of the synthesized audio, written
            with a sample rate of 24000.
        output_tts_inputs_embeds_path: Forwarded to
            ``_generate_speech_non_streaming``.
        omni_mode: With this or ``stream_input`` set, content parts are
            joined with ``""`` instead of ``"\\n"``.
        teacher_forcing: Omits the generation prompt and generates a single
            token, so that audio is produced for the given text.
        return_prompt: When truthy, also return the first rendered prompt.
        tts_proj_layer: Index into each step's hidden states used for TTS.
        tts_sampling_params: Forwarded to ``_generate_speech_non_streaming``.
        merge_audio_from_same_content: When falsy, ``None`` is passed to the
            processor in place of the per-conversation audio part indices.
        stream: Forwarded to ``self.generate``.
        stream_input: Passed to the processor (see also ``omni_mode``).
        tokenizer: Optional tokenizer override (see ``prepare_processor``).
        processor: Optional processor override (see ``prepare_processor``).
        **kwargs: Forwarded to ``prepare_generation_config``.

    Returns:
        The first decoded answer truncated at ``"<|tts_eos|>"`` (or ``None``
        if the first result is ``None``); a ``(answer, prompt)`` tuple when
        ``return_prompt`` is set.
    """
    from PIL import Image

    batched = isinstance(msgs[0], list)
    msgs_list = msgs
    images_list = image
    if not batched:
        images_list, msgs_list = [images_list], [msgs_list]
    else:
        assert images_list is None, "Please integrate image to msgs when using batch inference."
        images_list = [None] * len(msgs_list)
    assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

    self.prepare_processor(processor=processor, tokenizer=tokenizer)

    prompts_lists = []
    input_images_list = []
    input_audios_list = []
    audio_parts_list = []
    for image, msgs in zip(images_list, msgs_list):
        if isinstance(msgs, str):
            msgs = json.loads(msgs)
        copy_msgs = deepcopy(msgs)
        assert len(msgs) > 0, "msgs is empty"
        assert do_sample or not stream, "if use stream mode, make sure do_sample=True"
        if image is not None and isinstance(copy_msgs[0]["content"], str):
            copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

        images = []
        audios = []
        audio_parts = []
        for i, msg in enumerate(copy_msgs):
            role = msg["role"]
            content = msg["content"]
            assert role in ["system", "user", "assistant"]
            if i == 0:
                assert role in ["user", "system"], "The role of first msg should be user"
            # Normalize structured content (OpenAI format) to native format
            content = normalize_content(content)
            cur_msgs = []
            for c in content:
                if isinstance(c, Image.Image):
                    images.append(c)
                    cur_msgs.append("./")
                elif isinstance(c, np.ndarray):  # audio
                    audios.append(c)
                    # remember which message the audio came from
                    audio_parts.append(i)
                    cur_msgs.append("")
                    # any audio input switches the chat template to its TTS variant
                    use_tts_template = True
                elif isinstance(c, str):
                    cur_msgs.append(c)
            if omni_mode or stream_input:
                msg["content"] = "".join(cur_msgs)
            else:
                msg["content"] = "\n".join(cur_msgs)

        prompts_lists.append(
            self.processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                add_generation_prompt=not teacher_forcing,
                use_tts_template=use_tts_template,
                enable_thinking=enable_thinking,
            )
        )
        input_images_list.append(images)
        input_audios_list.append(audios)
        audio_parts_list.append(audio_parts)

    if not merge_audio_from_same_content:
        audio_parts_list = None

    inputs = self.processor(
        prompts_lists,
        input_images_list,
        input_audios_list,
        audio_parts_list,
        max_slice_nums=max_slice_nums,
        use_image_id=use_image_id,
        stream_input=stream_input,
        return_tensors="pt",
        max_length=max_inp_length,
    ).to(self.device)

    generation_config = self.prepare_generation_config(
        do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs
    )
    # max_new_tokens is passed explicitly below, so drop it from the config to avoid a duplicate keyword
    generation_config.pop("max_new_tokens", None)
    inputs.pop("image_sizes")

    # teacher_forcing = True => generate audio with given text
    with torch.inference_mode():
        res, outputs = self.generate(
            **inputs,
            tokenizer=self.processor.tokenizer,
            max_new_tokens=1 if teacher_forcing else max_new_tokens,
            vision_hidden_states=vision_hidden_states,
            stream=stream,
            **generation_config,
        )

    # spk bound and tts bound
    tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
    tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>")

    # Combine input_ids and generated sequences to get complete sequence
    input_ids = inputs["input_ids"][0]
    generated_ids = outputs.sequences[0]
    # Combine by concatenating input_ids with the new tokens from generated sequence
    full_sequence = torch.cat([input_ids, generated_ids])
    # Update the sequences in outputs
    full_sequences = full_sequence.unsqueeze(0)
    outputs["full_sequences"] = full_sequences

    tts_bos_indices = []
    tts_eos_indices = []
    for i, x in enumerate(full_sequences[0]):
        if x == tts_bos_token:
            # tts_bos + 1 is the position of the first tts, so that it is convenient to slice hidden states for tts
            tts_bos_indices.append(i + 1)
        elif x == tts_eos_token:
            # with teacher forcing, a tts_eos in the very last position is not counted
            if teacher_forcing and i == len(full_sequences[0]) - 1:
                continue
            tts_eos_indices.append(i)

    tts_bos_idx = tts_bos_indices[-1] if tts_bos_indices else -1
    # Use None instead of -1 when no EOS token found, so that slice [start:None]
    # means "to the end" rather than [start:-1] which excludes the last element
    tts_eos_idx = tts_eos_indices[-1] if tts_eos_indices else None
    tts_bound = (tts_bos_idx, tts_eos_idx)

    answer = res[0]
    if answer is not None:
        answer = answer.split("<|tts_eos|>")[0]

    if use_tts_template and generate_audio and output_audio_path:
        import soundfile as sf

        try:
            generated_waveform = self._generate_speech_non_streaming(
                outputs=outputs,
                tts_bound=tts_bound,
                tts_proj_layer=tts_proj_layer,
                audio_prompt=(
                    input_audios_list[0][0]
                    if len(input_audios_list) > 0 and len(input_audios_list[0]) > 0
                    else None
                ),
                output_tts_inputs_embeds_path=output_tts_inputs_embeds_path,
                tts_sampling_params=tts_sampling_params,
            )
            if isinstance(generated_waveform, torch.Tensor):
                sf.write(output_audio_path, generated_waveform.cpu().numpy(), samplerate=24000)
            elif isinstance(generated_waveform, np.ndarray):
                sf.write(output_audio_path, generated_waveform, samplerate=24000)
            logger.debug(f"audio saved to {output_audio_path}")
        except Exception:
            # Speech synthesis stays best-effort (the text answer is still
            # returned), but KeyboardInterrupt / SystemExit must propagate,
            # so only Exception subclasses are swallowed here.
            import traceback

            traceback.print_exc()

    if return_prompt:
        return answer, prompts_lists[0]
    else:
        return answer
@torch.inference_mode()
def _generate_speech_non_streaming(
    self,
    outputs,
    tts_bound,
    tts_proj_layer,
    audio_prompt,
    output_tts_inputs_embeds_path=None,
    tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
):
    """Synthesize a waveform for the text span delimited by ``tts_bound``.

    Args:
        outputs: Generation output providing ``hidden_states`` and the
            ``"full_sequences"`` entry.
        tts_bound: ``(start, end)`` indices used to slice both the token
            sequence and the stacked hidden states; ``end`` may be ``None``.
        tts_proj_layer: Index selecting which layer's hidden state is taken
            from every element of ``outputs.hidden_states``.
        audio_prompt: Optional reference audio; it is written to a temporary
            WAV file at 16000 Hz and handed to the audio tokenizer.
        output_tts_inputs_embeds_path: If set, the TTS input embeddings are
            saved there with ``torch.save``.
        tts_sampling_params: Passed to ``self.tts.generate``.

    Returns:
        torch.Tensor: ``float32`` tensor built from the decoded WAV bytes.

    Raises:
        NotImplementedError: If ``self.tts.condition_type`` is not
            ``"hidden_text_merge"``.
    """
    last_hidden_states = [hs[tts_proj_layer] for hs in outputs.hidden_states]
    last_hidden_states = torch.vstack([i[0] for i in last_hidden_states])

    # zero-row placeholder so the concatenation below keeps a fixed layout
    spk_embeds = (
        torch.ones([0, self.tts.config.hidden_size]).to(last_hidden_states.device).to(last_hidden_states.dtype)
    )

    if self.tts.condition_type == "hidden_text_merge":
        llm_tokens = outputs["full_sequences"][0][tts_bound[0] : tts_bound[1]]
        # The slice is already a tensor: move/cast it instead of re-wrapping it
        # with torch.tensor(), which warns about copy-constructing from a tensor.
        llm_tokens = llm_tokens.detach().to(device=self.tts.emb_text.weight.device, dtype=torch.long)
        llm_embeds = self.tts.emb_text(llm_tokens)  # make sure emb_text is compatible with llm vocab size
        hidden_embeds = last_hidden_states[tts_bound[0] : tts_bound[1]]
        hidden_embeds = self.tts.projector_semantic(hidden_embeds)
        if self.tts.config.normalize_projected_hidden:
            hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
        tts_embeds = llm_embeds + hidden_embeds
    else:
        raise NotImplementedError

    audio_bos = [self.tts.audio_bos_token_id]
    audio_bos = torch.tensor(audio_bos, device=self.tts.emb_text.weight.device, dtype=torch.long)
    audio_bos_embeds = self.tts.emb_text(audio_bos)
    text_eos_embed = self.tts.emb_text(
        torch.tensor(
            [self.tts.config.text_eos_token_id],
            device=self.tts.emb_text.weight.device,
            dtype=torch.long,
        )
    )
    inputs_embeds = torch.cat([spk_embeds, tts_embeds, text_eos_embed, audio_bos_embeds], dim=0).unsqueeze(0)

    # save inputs_embeds to file
    if output_tts_inputs_embeds_path:
        torch.save(inputs_embeds, output_tts_inputs_embeds_path)

    outputs = self.tts.generate(
        inputs_embeds=inputs_embeds,
        sampling_params=tts_sampling_params,
        eos_token=torch.tensor(
            [self.tts.config.num_audio_tokens - 1],
            dtype=torch.long,
            device=self.tts.device,
        ),
    )

    import io

    import soundfile as sf

    generated_tokens = outputs.new_ids.squeeze(-1)
    reference_audio = audio_prompt
    prompt_wav_path = None
    try:
        if reference_audio is not None:
            logger.debug("use reference audio in data to generate waveform")
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
                prompt_wav_path = tmp_wav.name
            # write after the handle is closed so the path can be reopened on every platform
            sf.write(prompt_wav_path, reference_audio, 16000)
        wav_bytes = self.tts.audio_tokenizer(
            generated_tokens.squeeze(0).tolist(),
            prompt_wav_path,
        )
    finally:
        # delete=False means nobody else removes the prompt file; the tokenizer
        # has already returned, so clean it up instead of leaking one per call
        if prompt_wav_path is not None:
            try:
                os.remove(prompt_wav_path)
            except OSError:
                logger.debug(f"could not remove temporary prompt wav {prompt_wav_path}")

    # convert wav bytes back to tensor for caller compatibility
    waveform, sr = sf.read(io.BytesIO(wav_bytes))
    return torch.tensor(waveform, dtype=torch.float32)
@torch.inference_mode()
def init_token2wav_cache(self, prompt_speech_16k):
    """Build ``self.token2wav_cache`` from a prompt utterance.

    Args:
        prompt_speech_16k: Prompt audio. On the ``set_stream_cache`` path it
            is written to a temporary WAV file at 16000 Hz; otherwise it is
            passed to the tokenizer frontend as ``prompt_speech_16k``.

    Two audio-tokenizer flavours are handled: one exposing
    ``set_stream_cache`` (driven through a WAV file on disk) and one driven
    through ``frontend.frontend_token2wav`` plus
    ``model.prepare_cache_from_prompt``.
    """
    import soundfile as sf

    audio_tokenizer = self.tts.audio_tokenizer

    if hasattr(audio_tokenizer, "set_stream_cache"):
        audio_tokenizer.cache = None
        # NOTE: the file is created with delete=False and is not removed here.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
            prompt_wav_path = wav_file.name
            sf.write(prompt_wav_path, prompt_speech_16k, 16000)
        flow_base, hift_base = audio_tokenizer.set_stream_cache(prompt_wav_path)
        # keep independent copies of the base caches
        self.token2wav_cache = {
            "flow_cache_base": torch_clone_recursive(flow_base),
            "hift_cache_base": torch_clone_recursive(hift_base),
        }
        return

    placeholder_tokens = torch.zeros(1, 1, dtype=torch.long, device=self.tts.device)
    frontend_out = audio_tokenizer.frontend.frontend_token2wav(
        speech_tokens=placeholder_tokens,
        speech_16k=None,
        prompt_speech_16k=prompt_speech_16k,
        resample_rate=audio_tokenizer.sample_rate,
        prompt_speech=None,
    )
    prompt_token = frontend_out["flow_prompt_speech_token"]
    prompt_feat = frontend_out["prompt_speech_feat"]
    embedding = frontend_out["flow_embedding"]
    if audio_tokenizer.fp16:
        prompt_feat = prompt_feat.to(torch.half)
        embedding = embedding.to(torch.half)

    tts_config = self.tts.config
    self.token2wav_cache = audio_tokenizer.model.prepare_cache_from_prompt(
        prompt_token=prompt_token,
        prompt_feat=prompt_feat,
        embedding=embedding,
        n_timesteps=tts_config.s3_stream_n_timesteps,
        code_chunk_size=tts_config.s3_stream_chunk_size,
        chunk_prelook_size=tts_config.s3_stream_prelook_size,
        use_attn_idx=False,
    )
# for sliding window
def _ensure_dynamic_cache(self):
cache = self.llm_past_key_values
if cache is None:
return None
cache = as_dynamic_cache(cache)
if isinstance(cache, DynamicCache):
self.llm_past_key_values = cache
return cache
return None
def _get_kv_cache_length(self, cache=None):
    """Return ``get_kv_cache_length`` of ``cache``.

    Args:
        cache: Cache to measure; ``self.llm_past_key_values`` is used when
            this is ``None``.
    """
    if cache is None:
        cache = self.llm_past_key_values
    return get_kv_cache_length(cache)
# TODO: this method appears to be unused; consider deleting it.
def _rebuild_cache_from_history(self):
preserved_ids: List[torch.Tensor] = []
for entry in self._omni_chunk_history:
ids = entry.get("input_ids")
if ids is None or not isinstance(ids, torch.Tensor) or ids.numel() == 0:
continue
preserved_ids.append(ids.to(self.device))
if not preserved_ids:
self.llm_past_key_values = None
self.streaming_position_offset = 0
self._rope_inv_freq_cache.clear()
return
concat_ids = torch.cat(preserved_ids, dim=1)
attention_mask = torch.ones((1, concat_ids.shape[1]), dtype=torch.bool, device=self.device)
outputs = self.llm(
input_ids=concat_ids,
attention_mask=attention_mask,
use_cache=True,
return_dict=True,
)
self.llm_past_key_values = outputs.past_key_values
self.streaming_position_offset = 0
self._rope_inv_freq_cache.clear()
def _get_rope_theta(self) -> float:
return float(getattr(self.llm.config, "rope_theta", 10000.0))
def _realign_rotary_suffix(
    self,
    suffix_keys: torch.Tensor,
    old_positions: torch.Tensor,
    new_positions: torch.Tensor,
) -> torch.Tensor:
    """Delegate to ``realign_rotary_suffix`` with this model's RoPE settings.

    Args:
        suffix_keys: Forwarded unchanged as the first argument.
        old_positions: Forwarded unchanged as the second argument.
        new_positions: Forwarded unchanged as the third argument.

    Returns:
        Whatever ``realign_rotary_suffix`` returns.
    """
    theta = self._get_rope_theta()
    freq_cache = self._rope_inv_freq_cache
    return realign_rotary_suffix(
        suffix_keys, old_positions, new_positions, rope_theta=theta, inv_freq_cache=freq_cache
    )
def _encode_text(self, tokenizer, text) -> Optional[torch.Tensor]:
if tokenizer is None or not text:
return None
ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
return ids.to(self.device)
@staticmethod
def _safe_decode(tokenizer, input_ids):
if tokenizer is None or input_ids is None:
return None
if isinstance(input_ids, torch.Tensor):
ids = input_ids.cpu().tolist()
if ids and isinstance(ids[0], list):
ids = ids[0]
else:
ids = input_ids
try:
return tokenizer.decode(ids, skip_special_tokens=False)
except Exception:
return None
def _finalize_round(
self, round_id: Optional[int], cache_before: int, assistant_input_ids: Optional[torch.Tensor] = None
):
if round_id is None:
self._pending_round_id = None
return
cache_after = self._get_kv_cache_length()
if assistant_input_ids is not None:
assistant_len = assistant_input_ids.shape[1]
else:
assistant_len = max(cache_after - cache_before, 0)
if assistant_len > 0:
self._register_chunk(
assistant_len,
"assistant",
round_id=round_id,
input_ids=assistant_input_ids,
tokenizer=self.processor.tokenizer if hasattr(self, "processor") else None,
)
self._pending_round_id = None
self._next_round_id += 1
def _register_chunk(
self,
seq_len: int,
chunk_type: str,
*,
round_id: int,
input_ids=None,
tokenizer=None,
) -> None:
if seq_len <= 0:
return
entry = {"length": int(seq_len), "type": chunk_type, "round": round_id}
if input_ids is not None:
entry["input_ids"] = input_ids.clone().detach()
entry["decoded"] = self._safe_decode(tokenizer, entry["input_ids"])
else:
entry["input_ids"] = None
entry["decoded"] = None
self._omni_chunk_history.append(entry)
if chunk_type == "system":
self.streaming_text_preserve = max(self.streaming_text_preserve, entry["length"])
def _drop_tokens_from_cache(self, length: int, cache: DynamicCache) -> bool:
    """Drop tokens from cache using the utility function.

    Args:
        length: Number of tokens to drop, forwarded as ``length``.
        cache: Cache to modify, forwarded as ``cache``.

    Returns:
        The success flag reported by ``drop_tokens_from_cache``. On success
        ``self.streaming_position_offset`` is updated to the returned offset.
    """
    drop_options = {
        "cache": cache,
        "length": length,
        "preserve": self.streaming_text_preserve,
        "position_offset": self.streaming_position_offset,
        "rope_theta": self._get_rope_theta(),
        "inv_freq_cache": self._rope_inv_freq_cache,
    }
    _unused, shifted_offset, ok = drop_tokens_from_cache(**drop_options)
    if ok:
        self.streaming_position_offset = shifted_offset
    return ok
def _drop_next_round(self, cache: DynamicCache) -> bool:
seen_rounds = set()
for entry in self._omni_chunk_history:
round_id = entry.get("round")
if round_id is None or round_id in seen_rounds:
continue
seen_rounds.add(round_id)
round_entries = [e for e in self._omni_chunk_history if e.get("round") == round_id]
if any(e.get("type") == "system" for e in round_entries):
continue
if self._drop_round(round_id, cache):
return True
return False
def _drop_round(self, round_id: int, cache: DynamicCache) -> bool:
entries = [e for e in self._omni_chunk_history if e.get("round") == round_id]
if not entries:
return False
total_len = sum(e["length"] for e in entries)
if total_len <= 0:
for e in entries:
self._omni_chunk_history.remove(e)
return False
if not self._drop_tokens_from_cache(total_len, cache):
return False
for e in entries:
self._omni_chunk_history.remove(e)
return True
def _enforce_text_window(self) -> None:
if not self.streaming_window_enabled:
return
cache = self._ensure_dynamic_cache()
if cache is None:
return
high_limit = max(0, int(self.streaming_window_config.text_window_high_tokens))
low_limit = max(0, int(self.streaming_window_config.text_window_low_tokens))
if high_limit <= 0:
return
target = max(0, low_limit)
total_len = self._get_kv_cache_length(cache)
if total_len <= high_limit:
return
dropped_any = False
while total_len > target:
if not self._drop_next_round(cache):
break
dropped_any = True
total_len = self._get_kv_cache_length(cache)
# speculative snapshot helpers (save / restore state around VAD speculation)
def save_speculative_snapshot(self) -> SpeculativeSnapshot:
    """Capture the state needed to undo a speculative generation.

    The snapshot is only built and returned here; this method does not
    store it on ``self`` (``streaming_generate`` assigns the result to
    ``self._speculative_snapshot``).

    Save strategy:
    - LLM KV Cache: only record length (restore by truncation, zero extra VRAM)
    - Audio KV Cache: deep clone (as generate sets it to None)
    - Mel processor: full state snapshot (including buffer)
    - RNG: CPU state always, CUDA state when running on a CUDA device

    Returns:
        SpeculativeSnapshot: the captured state, including checksum fields
        that are filled in for debugging.
    """
    # get LLM cache information
    llm_cache_length = self._get_kv_cache_length()
    llm_cache_checksum = None
    if self.llm_past_key_values is not None and hasattr(self.llm_past_key_values, "key_cache"):
        if len(self.llm_past_key_values.key_cache) > 0:
            # checksum of the first layer's keys only
            llm_cache_checksum = self.llm_past_key_values.key_cache[0].sum().item()
    # get audio cache length and clone audio_past_key_values
    audio_cache_length = 0
    audio_cache_checksum = None
    audio_past_key_values_clone = None
    if self.audio_past_key_values is not None:
        # handle DynamicCache format (Whisper encoder may return this format)
        if isinstance(self.audio_past_key_values, DynamicCache):
            if hasattr(self.audio_past_key_values, "key_cache") and len(self.audio_past_key_values.key_cache) > 0:
                # length is read from dim 2 of the first layer's keys
                audio_cache_length = self.audio_past_key_values.key_cache[0].shape[2]
                audio_cache_checksum = self.audio_past_key_values.key_cache[0].sum().item()
                # deep clone DynamicCache
                cloned_cache = DynamicCache()
                for k, v in zip(self.audio_past_key_values.key_cache, self.audio_past_key_values.value_cache):
                    # layer_idx is the number of layers cloned so far, i.e. append in order
                    cloned_cache.update(k.clone(), v.clone(), layer_idx=len(cloned_cache.key_cache))
                audio_past_key_values_clone = cloned_cache
        # handle EncoderDecoderCache format
        elif isinstance(self.audio_past_key_values, EncoderDecoderCache):
            self_attn_cache = self.audio_past_key_values.self_attention_cache
            if hasattr(self_attn_cache, "key_cache") and len(self_attn_cache.key_cache) > 0:
                audio_cache_length = self_attn_cache.key_cache[0].shape[2]
                audio_cache_checksum = self_attn_cache.key_cache[0].sum().item()
            # deep clone EncoderDecoderCache
            cloned_self_attn = DynamicCache()
            if hasattr(self_attn_cache, "key_cache"):
                for k, v in zip(self_attn_cache.key_cache, self_attn_cache.value_cache):
                    cloned_self_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_self_attn.key_cache))
            cross_attn_cache = self.audio_past_key_values.cross_attention_cache
            cloned_cross_attn = DynamicCache()
            if hasattr(cross_attn_cache, "key_cache"):
                for k, v in zip(cross_attn_cache.key_cache, cross_attn_cache.value_cache):
                    cloned_cross_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_cross_attn.key_cache))
            audio_past_key_values_clone = EncoderDecoderCache(cloned_self_attn, cloned_cross_attn)
        # handle tuple format (compatible with old format)
        elif isinstance(self.audio_past_key_values, tuple) and len(self.audio_past_key_values) > 0:
            audio_cache_length = self.audio_past_key_values[0][0].shape[2]
            audio_cache_checksum = self.audio_past_key_values[0][0].sum().item()
            # deep clone audio_past_key_values (tuple of tuples of tensors)
            audio_past_key_values_clone = tuple(
                tuple(t.clone() for t in layer_cache) for layer_cache in self.audio_past_key_values
            )
    # get mel processor snapshot
    mel_processor_snapshot = None
    mel_buffer_checksum = None
    if hasattr(self, "processor") and self.processor is not None:
        mel_processor_snapshot = self.processor.get_streaming_snapshot()
        if mel_processor_snapshot:
            buf = mel_processor_snapshot.get("buffer")
            if buf is not None and len(buf) > 0:
                mel_buffer_checksum = float(buf.sum())
    # save RNG state (important: for deterministic dithering and other random operations after restoration)
    rng_state_cpu = torch.get_rng_state()
    rng_state_cuda = None
    if torch.cuda.is_available() and self.device.type == "cuda":
        rng_state_cuda = torch.cuda.get_rng_state(self.device)
    # create snapshot
    snapshot = SpeculativeSnapshot(
        llm_cache_length=llm_cache_length,
        audio_cache_length=audio_cache_length,
        new_user_msg=self.new_user_msg,
        llm_generated=self.llm_generated,
        llm_generate_completed=self.llm_generate_completed,
        next_round_id=self._next_round_id,
        pending_round_id=self._pending_round_id,
        # only the history length is stored; restore truncates back to it
        omni_chunk_history_length=len(self._omni_chunk_history),
        tts_last_turn_tokens=self.tts_last_turn_tokens.clone() if self.tts_last_turn_tokens is not None else None,
        audio_chunk_idx=self.audio_chunk_idx,
        mel_processor_snapshot=mel_processor_snapshot,
        audio_past_key_values=audio_past_key_values_clone,
        timestamp=time.time(),
        # debug fields
        llm_cache_checksum=llm_cache_checksum,
        audio_cache_checksum=audio_cache_checksum,
        mel_buffer_checksum=mel_buffer_checksum,
        # RNG state
        rng_state_cpu=rng_state_cpu,
        rng_state_cuda=rng_state_cuda,
    )
    return snapshot
def restore_speculative_snapshot(self, snapshot=None) -> bool:
    """Roll the model back to a speculative snapshot (used when VAD speculation fails).

    Model state is reset to what it was before ``streaming_generate`` ran,
    so that ``streaming_prefill`` can continue with newly arrived audio.

    Args:
        snapshot: Snapshot to restore. When falsy, the one stored on
            ``self._speculative_snapshot`` (if any) is used.

    Notes:
        - ``streaming_generate`` stores a snapshot when called with
          ``enable_speculative_snapshot=True``.
        - The stored snapshot is cleared after a successful restore, so it
          cannot be restored twice.

    Returns:
        bool: ``True`` on success; ``False`` when no snapshot is available or
        an exception occurred (the traceback is logged).
    """
    snapshot = snapshot or getattr(self, "_speculative_snapshot", None)
    if snapshot is None:
        return False

    try:
        cache_len_now = self._get_kv_cache_length()
        history_len_now = len(self._omni_chunk_history)

        # LLM KV cache: only ever shrink back to the recorded length
        if cache_len_now > snapshot.llm_cache_length:
            self._truncate_llm_cache(snapshot.llm_cache_length)

        # audio KV cache: taken from the copy held by the snapshot, because
        # streaming_generate sets audio_past_key_values to None
        self.audio_past_key_values = snapshot.audio_past_key_values

        # session flags
        self.new_user_msg = snapshot.new_user_msg
        self.llm_generated = snapshot.llm_generated
        self.llm_generate_completed = snapshot.llm_generate_completed

        # round bookkeeping
        self._next_round_id = snapshot.next_round_id
        self._pending_round_id = snapshot.pending_round_id

        # chunk history: cut off everything appended after the snapshot
        kept_entries = snapshot.omni_chunk_history_length
        if history_len_now > kept_entries:
            self._omni_chunk_history = self._omni_chunk_history[:kept_entries]

        # TTS and streaming-processor state
        self.tts_last_turn_tokens = snapshot.tts_last_turn_tokens
        self.audio_chunk_idx = snapshot.audio_chunk_idx

        # mel processor state (otherwise subsequent prefill will fail due to frame number mismatch)
        mel_state = snapshot.mel_processor_snapshot
        if mel_state is not None and getattr(self, "processor", None) is not None:
            self.processor.restore_streaming_snapshot(mel_state)

        # RNG state, so dithering and other random operations stay deterministic
        if snapshot.rng_state_cpu is not None:
            torch.set_rng_state(snapshot.rng_state_cpu)
        if snapshot.rng_state_cuda is not None and torch.cuda.is_available():
            torch.cuda.set_rng_state(snapshot.rng_state_cuda, self.device)

        # temporary attributes created during generation
        for temp_attr in ("_streaming_generated_token_ids", "_last_streaming_text"):
            if hasattr(self, temp_attr):
                delattr(self, temp_attr)

        # a snapshot can only be restored once
        self._speculative_snapshot = None
    except Exception:
        import traceback

        logger.error(traceback.format_exc())
        return False
    return True
def has_speculative_snapshot(self) -> bool:
    """Tell whether a speculative snapshot is currently stored on this instance."""
    stored = getattr(self, "_speculative_snapshot", None)
    return stored is not None
def clear_speculative_snapshot(self) -> None:
    """Discard the stored speculative snapshot, if the attribute exists."""
    if not hasattr(self, "_speculative_snapshot"):
        return
    self._speculative_snapshot = None
def _truncate_llm_cache(self, target_length: int) -> None:
if self.llm_past_key_values is None:
return
cache = self._ensure_dynamic_cache()
if cache is None:
return
current_length = self._get_kv_cache_length(cache)
if current_length <= target_length:
return
# truncate each layer of cache
for layer_idx in range(len(cache.key_cache)):
if cache.key_cache[layer_idx].numel() > 0:
cache.key_cache[layer_idx] = cache.key_cache[layer_idx][:, :, :target_length, :].contiguous()
cache.value_cache[layer_idx] = cache.value_cache[layer_idx][:, :, :target_length, :].contiguous()
# update cache metadata
cache.crop(target_length)
cache._seen_tokens = target_length
@torch.inference_mode()
def streaming_prefill(
    self,
    session_id,
    msgs,
    omni_mode=True,
    max_slice_nums=None,
    use_tts_template=True,
    enable_thinking=False,
    is_last_chunk=False,  # for audio chunk, if is the last chunk, set to True
    tokenizer=None,
    processor=None,
    **kwargs,
):
    """Prefill one message segment into the LLM KV cache of a streaming session.

    Args:
        session_id: Session identifier; must not be ``None``. A value that
            differs from ``self.session_id`` (or an unset ``self.session_id``)
            marks this call as the first of a new session and resets state.
        msgs: List holding exactly one message dict with ``"role"`` and
            ``"content"``. Content items may be PIL images, ``np.ndarray``
            audio or strings; other items are logged and ignored.
            NOTE(review): content is iterated directly, so it is presumably
            expected to be a list rather than a plain string - confirm.
        omni_mode: Join content parts with ``""`` when truthy, ``"\\n"`` otherwise.
        max_slice_nums: Passed to the processor; ``1`` is used when ``None``.
        use_tts_template: Passed to ``apply_chat_template`` (only used for a
            system/assistant message on the first call of a session).
        enable_thinking: Passed to ``apply_chat_template`` (same condition).
        is_last_chunk: Forwarded to the processor.
        tokenizer: Optional tokenizer override (see ``prepare_processor``).
        processor: Optional processor override (see ``prepare_processor``).
        **kwargs: Not used in this method body.

    Returns:
        str: The prompt text that was built and prefilled.
    """
    from PIL import Image

    assert session_id is not None, "session_id cannot be None"
    self.is_first = self.session_id is None or session_id != self.session_id
    self.prepare_processor(processor=processor, tokenizer=tokenizer)
    images = []
    audios = []
    assert len(msgs) == 1
    # work on a copy so the caller's message is not modified
    copy_msgs = deepcopy(msgs)
    msg = copy_msgs[0]
    assert msg["role"] in ["system", "user", "assistant"]
    is_not_system_prefill = msg["role"] != "system"
    content = msg["content"]
    cur_msgs = []
    # split content into images, audios and the text pieces of the prompt
    for j, c in enumerate(content):
        if isinstance(c, Image.Image):
            images.append(c)
            cur_msgs.append("./")
        elif isinstance(c, np.ndarray):
            audios.append(c)
            cur_msgs.append("")
        elif isinstance(c, str):
            cur_msgs.append(c)
        else:
            logger.error(f"Invalid content type: {c}, ignore it.")
    cur_contents = "".join(cur_msgs) if omni_mode else "\n".join(cur_msgs)
    # a system/assistant message means the next user segment starts a new user turn
    if msg["role"] in ["system", "assistant"]:
        self.new_user_msg = True
        self.audio_past_key_values = None
    if self.is_first:
        self.reset_session(reset_token2wav_cache=False)
        self.session_id = session_id
        self.init_streaming_processor()
        if msg["role"] == "user":
            # no system prefill, the first segment of the first user turn
            # do not use apply_chat_template, manually build prompt to avoid automatic addition of <|im_end|>
            prompt = "<|im_start|>user\n" + cur_contents
            self.new_user_msg = False  # mark subsequent segments do not need to add user prefix anymore
        else:
            # system or assistant prefill, use apply_chat_template
            msg["content"] = cur_contents
            prompt = self.processor.tokenizer.apply_chat_template(
                copy_msgs,
                tokenize=False,
                add_generation_prompt=False,
                use_tts_template=use_tts_template,
                enable_thinking=enable_thinking,
            )
        add_special_tokens = True  # add bos
    else:
        # non-first prefill
        if self.new_user_msg and msg["role"] == "user":
            # the first segment of the new user turn
            if self.llm_generated:
                # close the previous assistant turn; an unfinished generation
                # additionally gets an explicit <|tts_eos|>
                if self.llm_generate_completed:
                    prompt = "<|im_end|>\n<|im_start|>user\n" + cur_contents
                else:
                    prompt = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents
            else:
                prompt = "<|im_start|>user\n" + cur_contents
            self.new_user_msg = False
        else:
            # subsequent segments of the same turn, directly use content
            prompt = cur_contents
        add_special_tokens = False
    # when first user audio prefill, ensure audio length satisfies FIRST_CHUNK_MS requirements
    if is_not_system_prefill and len(audios) > 0 and self.audio_chunk_idx == 0:
        assert len(audios) == 1, f"streaming mode only supports single audio, currently {len(audios)}"
        first_chunk_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000)
        if len(audios[0]) < first_chunk_samples:
            # left-pad with zeros up to the required number of samples
            pad_len = first_chunk_samples - len(audios[0])
            audios[0] = np.concatenate([np.zeros(pad_len, dtype=audios[0].dtype), audios[0]])
    model_inputs = self.processor(
        [prompt],
        [images],
        [audios],
        max_slice_nums=1 if max_slice_nums is None else max_slice_nums,
        use_image_id=False,
        chunk_input=True,
        return_tensors="pt",
        max_length=None,
        sampling_rate=16000,
        add_special_tokens=add_special_tokens,
        online_streaming=is_not_system_prefill,
        audio_chunk_idx=self.audio_chunk_idx,
        is_last_chunk=is_last_chunk,
    ).to(self.device)
    # only non-system audio segments advance the audio chunk counter
    if len(audios) > 0 and is_not_system_prefill:
        self.audio_chunk_idx += 1
    # 1. prepare input embeddings
    model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs)
    # get audio embedding with audio_past_key_values
    inputs_embeds = self.get_omni_embedding(
        model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=is_not_system_prefill
    )
    if self.is_first:
        self.audio_past_key_values = None
    round_id = self._next_round_id
    self._pending_round_id = round_id
    chunk_type = "system" if msg["role"] == "system" else ("user" if msg["role"] == "user" else "assistant")
    seq_len = inputs_embeds.shape[1]
    # trim the window before measuring the cache so the mask matches what is kept
    self._enforce_text_window()
    cache_length = self._get_kv_cache_length()
    # the mask spans the cached positions plus the new segment
    attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device)
    # 2. do prefill
    outputs = self.llm(
        past_key_values=self.llm_past_key_values,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=None,
        use_cache=True,
        return_dict=True,
    )
    self.llm_past_key_values = as_dynamic_cache(outputs["past_key_values"])
    # record the segment in the chunk history under the current round
    self._register_chunk(
        seq_len,
        chunk_type,
        round_id=round_id,
        input_ids=model_inputs["input_ids"],
        tokenizer=self.processor.tokenizer,
    )
    # trim again now that the new segment is part of the cache
    self._enforce_text_window()
    if self.force_rope_reindex:
        self._force_reindex_all_cache()
    return prompt
@torch.inference_mode()
def streaming_generate(
    self,
    session_id,
    bos_input=None,
    generate_audio=True,
    audio_token_chunk_size=25,  # 25 token/s
    tts_sampling_params: TTSSamplingParams = TTSSamplingParams(),
    max_new_tokens=256,
    enable_thinking=False,
    use_tts_template=True,
    do_sample=True,
    enable_speculative_snapshot=False,
    tokenizer=None,
    processor=None,
    # Teacher forcing (only for the "text → hidden → TTS condition" pipeline in streaming_generate)
    # When enabled: instead of letting the LLM auto-regressively generate the text to be spoken,
    # it forces the tokens from teacher_forcing_text to be fed in, using the hidden states
    # corresponding to these tokens to construct the TTS condition, ensuring the output audio matches the input text.
    teacher_forcing: bool = False,
    teacher_forcing_text: str = "",
    **kwargs,
):
    """Stream the assistant reply for the current turn as audio+text or as text only.

    This is a generator function: nothing below runs until the first item is
    requested. On start it resets the per-turn buffers, records the KV-cache
    length and pending round id, and re-initialises the streaming processor.
    Text is then produced in chunks of 10 LLM tokens (sampled, or forced from
    ``teacher_forcing_text``) and, when ``generate_audio`` is true, each text
    chunk is turned into audio tokens and vocoded into waveform chunks.

    Args:
        session_id: Not referenced anywhere in this method body.
        bos_input: Prompt text that opens the assistant turn. When ``None`` it
            is built from ``"<|im_end|>\\n<|im_start|>assistant\\n"``, the
            ``self.think_str`` block (omitted if ``enable_thinking``) and
            ``"<|tts_bos|>"`` (only if ``use_tts_template``).
        generate_audio: If true, run TTS and yield waveform chunks; otherwise
            only text is streamed.
        audio_token_chunk_size: Forwarded to ``TTSStreamingGenerator`` as
            ``chunk_size``. The token-to-waveform loop below uses its own fixed
            chunk size of 25.
        tts_sampling_params: Supplies ``repetition_penalty``, ``top_p``,
            ``top_k`` (for ``gen_logits``) and ``temperature`` for the TTS.
        max_new_tokens: Upper bound used to derive the number of 10-token LLM
            chunks (rounded up).
        enable_thinking: See ``bos_input``.
        use_tts_template: See ``bos_input``.
        do_sample: Forwarded to the LLM chunk generator.
        enable_speculative_snapshot: If true, store
            ``self.save_speculative_snapshot()`` in ``self._speculative_snapshot``
            before any state is modified.
        tokenizer: Forwarded to ``self.prepare_processor``. Encoding/decoding
            in this method uses ``self.processor.tokenizer``.
        processor: Forwarded to ``self.prepare_processor``.
        teacher_forcing: If true, feed ``teacher_forcing_text`` token by token
            instead of sampling from the LLM.
        teacher_forcing_text: Text to force when ``teacher_forcing`` is true.
        **kwargs: LLM sampling overrides read for the sampled path only:
            ``temperature`` (0.7), ``top_p`` (0.8), ``top_k`` (100),
            ``repetition_penalty`` (1.02), ``length_penalty`` (1.0).

    Yields:
        With ``generate_audio=True``: ``(waveform_chunk, new_text)`` where
        ``waveform_chunk`` is ``torch.from_numpy`` of the audio tokenizer's
        output and ``new_text`` is the text decoded since the previous yield.
        With ``generate_audio=False``: whatever ``streaming_token_decoder``
        yields for the text chunks.

    Raises:
        NotImplementedError: ``generate_audio`` is true and the TTS audio
            tokenizer type is not ``"s3tokenizer_step_audio"``.
    """
    # save speculative snapshot (before modifying any state)
    # for VAD speculative snapshot: if speculative snapshot fails, can call restore_speculative_snapshot() to restore
    # enable_speculative_snapshot=True when enabled, skip (save some overhead) when disabled
    if enable_speculative_snapshot:
        self._speculative_snapshot = self.save_speculative_snapshot()

    # reset buf
    self.new_user_msg = True
    self.llm_generated = True
    self.llm_generate_completed = False
    self.audio_past_key_values = None
    self.prepare_processor(processor=processor, tokenizer=tokenizer)

    # reset current turn generated token IDs
    if hasattr(self, "_streaming_generated_token_ids"):
        del self._streaming_generated_token_ids
    # reset full generated text
    if hasattr(self, "_last_streaming_text"):
        del self._last_streaming_text

    # Captured before generation so the round can be finalized against the pre-generation cache length.
    cache = self._ensure_dynamic_cache()
    cache_length = self._get_kv_cache_length(cache)
    host_round_id = self._pending_round_id

    # in single-turn streaming, each call to streaming_generate needs to reinitialize the
    # streaming_processor, enter the next turn
    self.init_streaming_processor()

    # 1) llm generate token and hidden states per chunk=10, 2) tts generate audio token chunk per chunk=25,
    # 3) yield 1 chunk audio token
    def audio_chunk_generator(
        bos_input,
        tokenizer,
        generate_audio,
        tts_sampling_params,
        max_new_tokens,
        do_sample,
        teacher_forcing=False,
        teacher_forcing_text="",
        **kwargs,
    ):
        """Yield ``(audio_token_chunk, is_last)`` pairs, or ``(token_ids, finished)`` in text-only mode.

        In audio mode a final ``(None, None)`` pair marks the end of the stream.
        """
        generate_chunk_size = 10

        def finalize_round_text():
            # Decode this round's token ids into _last_streaming_text and finalize the round.
            # Best effort: on any failure the text is reset to None, but the error is logged
            # instead of being dropped silently.
            if not hasattr(self, "_streaming_generated_token_ids"):
                self._last_streaming_text = None
                return
            try:
                self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids)
                assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text)
                self._finalize_round(
                    round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids
                )
            except Exception:
                logging.getLogger(__name__).exception("streaming_generate: failed to finalize round text")
                self._last_streaming_text = None

        if bos_input is None:
            bos_input = "".join(
                [
                    "<|im_end|>\n<|im_start|>assistant\n",
                    "" if enable_thinking else self.think_str.replace("\\n", "\n"),
                    "<|tts_bos|>" if use_tts_template else "",
                ]
            )

        bos_input_ids = tokenizer.encode(bos_input)
        bos_input_ids = torch.tensor(bos_input_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        bos_input_embeds = self.llm.get_input_embeddings()(bos_input_ids)

        generation_inputs_embeds = bos_input_embeds
        generated_ids = torch.empty((1, 0), dtype=torch.long, device=self.device)
        num_chunks_decode = (max_new_tokens + generate_chunk_size - 1) // generate_chunk_size

        # generate chunk by chunk, each chunk has 10 tokens, each chunk takes last hidden states, and pass tokens to tts
        # NOTE(review): the third terminator was an empty string in the reviewed copy, which looks like an
        # angle-bracket tag that was stripped; restored as "</s>" - confirm against the tokenizer vocabulary.
        llm_streaming_generator = ChunkPrefillChunkGenerate(
            model=self.llm,
            tokenizer=tokenizer,
            terminators=["<|tts_eos|>", "<|im_end|>", "</s>"],
        )

        if generate_audio:
            logits_warpers, logits_processors = gen_logits(
                num_code=self.tts.config.num_audio_tokens,
                repetition_penalty=tts_sampling_params.repetition_penalty,
                top_p=tts_sampling_params.top_p,
                top_k=tts_sampling_params.top_k,
            )
            tts_streaming_generator = TTSStreamingGenerator(
                model=self.tts,
                temperature=tts_sampling_params.temperature,
                eos_token=torch.tensor(
                    [self.tts.config.num_audio_tokens - 1],
                    dtype=torch.long,
                    device=self.tts.device,
                ),
                chunk_size=audio_token_chunk_size,  # s3tokenizer 1s = 25token
                tts_last_turn_tokens=self.tts_last_turn_tokens,
                logits_processors=logits_processors,
                logits_warpers=logits_warpers,
            )

        # Teacher forcing branch
        # This branch does not rely on ChunkPrefillChunkGenerate's sampling logic, instead:
        # 1) First prefill bos_input (assistant + tts_bos) into llm_past_key_values
        # 2) Tokenize teacher_forcing_text into token ids
        # 3) Feed tokens one by one into the LLM (teacher forcing), obtaining the last_hidden_states for each token
        # 4) Use (token_ids, hidden_states) to construct tts condition, then feed it to TTSStreamingGenerator
        if teacher_forcing:
            # --- 1) prefill bos_input, continuing from the KV cache built by streaming_prefill ---
            bos_outputs = self.llm(
                inputs_embeds=generation_inputs_embeds,
                past_key_values=self.llm_past_key_values,
                use_cache=True,
                output_hidden_states=True,
                return_dict=True,
            )
            self.llm_past_key_values = bos_outputs.past_key_values

            if generate_audio:
                # Give a length-0 tensor as speaker embedding (no speaker embedding)
                spk_emb = torch.empty(
                    (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]),
                    dtype=bos_input_embeds.dtype,
                    device=bos_input_embeds.device,
                )
                tts_streaming_generator.spk_emb = spk_emb

            # --- 2) tokenize teacher_forcing_text ---
            tf_text = teacher_forcing_text or ""
            # Tokenize once; only the way the ids are read off the result has a fallback.
            encoded = tokenizer(tf_text, add_special_tokens=False, return_tensors="pt")
            try:
                forced_input_ids = encoded["input_ids"]
            except (TypeError, KeyError):
                # Compatible with rare tokenizer return objects that only expose attributes
                forced_input_ids = encoded.input_ids
            forced_input_ids = forced_input_ids.to(self.device)

            total_len = int(forced_input_ids.shape[1])
            ptr = 0

            # Special case: empty text should also let TTS finish
            # (text_finished=True will automatically concatenate text_eos_embed)
            if total_len == 0:
                if not generate_audio:
                    yield forced_input_ids, True
                    return

                empty_tts_embeds = torch.empty(
                    (1, 0, self.tts.config.hidden_size),
                    dtype=bos_input_embeds.dtype,
                    device=self.device,
                )
                if not hasattr(self, "_streaming_generated_token_ids"):
                    self._streaming_generated_token_ids = []

                tts_generator = tts_streaming_generator.generate_with_buffer(
                    condition=empty_tts_embeds,
                    text_finished=True,
                )
                for audio_token_chunk, is_last_audio_chunk in tts_generator:
                    yield audio_token_chunk, is_last_audio_chunk

                self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
                self._last_streaming_text = ""
                yield None, None
                return

            # --- 3) chunk-by-chunk teacher forcing ---
            while ptr < total_len:
                end = min(ptr + generate_chunk_size, total_len)
                chunk_ids = forced_input_ids[:, ptr:end]  # [1, chunk_len]

                chunk_hidden_list = []
                for j in range(chunk_ids.shape[1]):
                    tok = chunk_ids[:, j : j + 1]  # [1, 1]
                    tok_emb = self.llm.get_input_embeddings()(tok)
                    out = self.llm(
                        inputs_embeds=tok_emb,
                        past_key_values=self.llm_past_key_values,
                        use_cache=True,
                        output_hidden_states=True,
                        return_dict=True,
                    )
                    self.llm_past_key_values = out.past_key_values
                    chunk_hidden_list.append(out.hidden_states[-1])  # [1, 1, hidden]

                chunk_hidden = torch.cat(chunk_hidden_list, dim=1)  # [1, chunk_len, hidden]
                text_finished = end >= total_len

                # Save token IDs cache (external eval script will use _last_streaming_text to write generated_text)
                if not hasattr(self, "_streaming_generated_token_ids"):
                    self._streaming_generated_token_ids = []
                self._streaming_generated_token_ids.extend(chunk_ids[0].tolist())

                if not generate_audio:
                    yield chunk_ids, text_finished
                else:
                    llm_embeds = self.tts.emb_text(chunk_ids)
                    hidden_embeds = self.tts.projector_semantic(chunk_hidden)
                    if self.tts.config.normalize_projected_hidden:
                        hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
                    tts_embeds = llm_embeds + hidden_embeds

                    tts_generator = tts_streaming_generator.generate_with_buffer(
                        condition=tts_embeds,
                        text_finished=text_finished,
                    )
                    for audio_token_chunk, is_last_audio_chunk in tts_generator:
                        yield audio_token_chunk, is_last_audio_chunk

                ptr = end
                if text_finished:
                    if generate_audio:
                        self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
                    break

            # Finish: decode this round of text (done for both audio and text-only teacher forcing)
            finalize_round_text()

            # Finally send the end signal
            if generate_audio:
                yield None, None
            return

        # Sampling settings do not change between chunks, so resolve them once.
        llm_sampling_kwargs = dict(
            do_sample=do_sample,
            temperature=kwargs.get("temperature", 0.7),
            top_p=kwargs.get("top_p", 0.8),
            top_k=kwargs.get("top_k", 100),
            repetition_penalty=kwargs.get("repetition_penalty", 1.02),
            length_penalty=kwargs.get("length_penalty", 1.0),
        )

        # LLM chunk generate outer loop
        for chunk_idx in range(num_chunks_decode):
            is_first_generate_chunk = chunk_idx == 0
            output = llm_streaming_generator.chunk_generate(
                inputs_embeds=generation_inputs_embeds,
                past_key_values=self.llm_past_key_values,
                is_first_generate_chunk=is_first_generate_chunk,
                return_hidden_states=True,
                chunk_size=generate_chunk_size + 1 * is_first_generate_chunk,
                all_input_ids=generated_ids,
                **llm_sampling_kwargs,
            )

            if output.chunk_token_ids is None:
                break

            if is_first_generate_chunk:
                if generate_audio:
                    # Length-0 speaker embedding (no speaker embedding)
                    spk_emb = torch.empty(
                        (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]),
                        dtype=bos_input_embeds.dtype,
                        device=bos_input_embeds.device,
                    )
                    tts_streaming_generator.spk_emb = spk_emb

                if output.finished:
                    yield_chunk_token_ids = output.chunk_token_ids
                else:
                    # the first chunk generated chunk_size + 1 tokens, we only take the first chunk_size tokens,
                    # the last token is not prefilled, and last hidden states is not obtained
                    yield_chunk_token_ids = output.chunk_token_ids[:, :-1]
            elif output.finished:
                yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids], dim=1)
            else:
                # in the chunk that is not the first chunk, we need to add the token at the end of the previous chunk,
                # it is not prefilled into the model to get last hidden states
                # similarly, the last generated token of subsequent chunks is not prefilled, and last hidden states
                # is not obtained, so it is not passed out
                yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids[:, :-1]], dim=1)

            if not generate_audio:
                yield yield_chunk_token_ids, output.finished
            else:
                # TTS inner loop
                # dense connection here is hardcoded to use text-hidden merged as condition
                llm_embeds = self.tts.emb_text(yield_chunk_token_ids)
                hidden_embeds = output.last_hidden_states
                hidden_embeds = self.tts.projector_semantic(hidden_embeds)
                if self.tts.config.normalize_projected_hidden:  # default should be opened
                    hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1)
                tts_embeds = llm_embeds + hidden_embeds

                # Store token IDs instead of decoded text to avoid UTF-8 multi-byte character truncation
                if not hasattr(self, "_streaming_generated_token_ids"):
                    self._streaming_generated_token_ids = []
                self._streaming_generated_token_ids.extend(yield_chunk_token_ids[0].tolist())

                # there is buffer generated, each time exactly returns 25 audio tokens,
                # the last audio chunk returns audio tokens of variable length, length [0, 25]
                tts_generator = tts_streaming_generator.generate_with_buffer(
                    condition=tts_embeds, text_finished=output.finished
                )
                for audio_token_chunk, is_last_audio_chunk in tts_generator:
                    yield audio_token_chunk, is_last_audio_chunk

            generated_ids = torch.cat([generated_ids, output.chunk_token_ids], dim=1)
            generation_inputs_embeds = output.current_inputs_embeds
            self.llm_past_key_values = output.past_key_values

            if output.finished:
                if generate_audio:
                    self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens
                break

        # Text-only sampling ends here without finalizing the round (unchanged behavior).
        if not generate_audio:
            return

        # IMPORTANT: Flush remaining TTS buffer when LLM generation ends
        # This handles BOTH cases:
        # 1. LLM finished with terminator (output.finished=True) - buffer may still have tokens
        # 2. LLM hit max chunks limit (output.finished=False) - buffer definitely has tokens
        if len(tts_streaming_generator._token_buffer) > 0:
            batch = torch.cat(tts_streaming_generator._token_buffer, dim=1)
            yield batch, True
            tts_streaming_generator._token_buffer = []

        finalize_round_text()
        yield None, None

    # iter for generating text chunk and audio chunk
    audio_chunk_generator_iter = audio_chunk_generator(
        bos_input=bos_input,
        tokenizer=self.processor.tokenizer,
        generate_audio=generate_audio,
        tts_sampling_params=tts_sampling_params,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        teacher_forcing=teacher_forcing,
        teacher_forcing_text=teacher_forcing_text,
        **kwargs,
    )

    if generate_audio:
        if self.tts.config.audio_tokenizer_type == "s3tokenizer_step_audio":
            # Start every turn from clones of the cached base states so turns do not share mutable caches.
            self.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.token2wav_cache["flow_cache_base"])
            self.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(
                self.token2wav_cache["hift_cache_base"]
            )

            # pre-insert 3-5 prefix 4218 silence tokens, each token corresponds to 0.04s,
            # adding 5 tokens means introducing 0.2s of silence
            buffer = [4218] * 3
            pre_lookahead = 3
            CHUNK_SIZE = 25
            prev_text_len = 0  # track text position for streaming text output

            for audio_token_chunk, is_last_audio_chunk in audio_chunk_generator_iter:
                if audio_token_chunk is None:
                    # (None, None) is the end-of-stream marker from audio_chunk_generator
                    break

                buffer += audio_token_chunk.reshape(-1).tolist()

                if len(buffer) >= CHUNK_SIZE + pre_lookahead:
                    waveform_chunk = self.tts.audio_tokenizer.stream(
                        buffer[: CHUNK_SIZE + pre_lookahead],
                        prompt_wav=None,
                        last_chunk=is_last_audio_chunk,
                        return_waveform=True,
                    )
                    waveform_chunk = torch.from_numpy(waveform_chunk)

                    # get new text chunk corresponding to this waveform
                    # Decode from accumulated token IDs to avoid UTF-8 multi-byte truncation
                    new_text = ""
                    if hasattr(self, "_streaming_generated_token_ids"):
                        current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids)
                        # Filter out trailing replacement characters (incomplete UTF-8 sequences)
                        safe_end = len(current_text)
                        while safe_end > 0 and current_text[safe_end - 1] == "\ufffd":
                            safe_end -= 1
                        safe_text = current_text[:safe_end]
                        new_text = safe_text[prev_text_len:]
                        prev_text_len = len(safe_text)

                    yield waveform_chunk, new_text

                    # Advance by CHUNK_SIZE only; the last pre_lookahead tokens stay in the buffer.
                    buffer = buffer[CHUNK_SIZE:]

            # flush rest
            if len(buffer) > 0:
                waveform_chunk = self.tts.audio_tokenizer.stream(
                    buffer,
                    prompt_wav=None,
                    last_chunk=True,
                    return_waveform=True,
                )
                waveform_chunk = torch.from_numpy(waveform_chunk)

                # get remaining new text for the final chunk
                # Final chunk: decode all remaining text without filtering
                new_text = ""
                if hasattr(self, "_streaming_generated_token_ids"):
                    current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids)
                    new_text = current_text[prev_text_len:]

                yield waveform_chunk, new_text
            # maybe the buffer is empty, and text is not empty, should we flush text without wave?
        else:
            raise NotImplementedError(f"not supported audio tokenizer: {self.tts.config.audio_tokenizer_type}")
    else:
        # For text-only generation, decode tokens and handle partial multi-byte characters
        yield from streaming_token_decoder(
            audio_chunk_generator_iter,
            self.processor.tokenizer,
            skip_special_tokens=False,
        )
def as_duplex(self, device: Optional[str] = None, **kwargs) -> "MiniCPMODuplex":
    """Wrap this model in a MiniCPMODuplex without reloading it.

    Args:
        device: Forwarded unchanged to ``MiniCPMODuplex.from_existing_model``.
        **kwargs: Forwarded unchanged to ``MiniCPMODuplex.from_existing_model``.

    Returns:
        The object returned by ``MiniCPMODuplex.from_existing_model``.
    """
    duplex = MiniCPMODuplex.from_existing_model(model=self, device=device, **kwargs)
    return duplex
class MiniCPMODuplex:
"""MiniCPMODuplex model with full-duplex streaming capabilities.
This is a wrapper class that provides duplex streaming functionality.
Use MiniCPMO.as_duplex() to create from an existing model without reloading.
"""
# Default duplex parameters.
# from_existing_model reads every key through its get_param helper: a keyword argument
# with the same name wins, otherwise the value below is used.
_default_duplex_params = {
    # Stored directly on the instance (generate_audio also gates the TTS logits/eos setup)
    "generate_audio": True,
    "ls_mode": "explicit",
    # LLM generation config
    "max_new_speak_tokens_per_chunk": 20,
    "text_repetition_penalty": 1.05,
    "temperature": 0.7,
    "top_k": 100,
    "top_p": 0.8,
    "text_repetition_window_size": 512,
    "listen_prob_scale": 1.0,
    "force_listen_count": 0,
    # TTS generation config
    "tts_temperature": 0.8,
    "tts_repetition_penalty": 1.05,
    # Passed to model.init_tts(enable_float16=..., n_timesteps=...)
    "enable_float16": False,
    "n_timesteps": 10,
    # Stream config (copied to both the duplex instance and the wrapped model)
    "chunk_ms": 1000,
    "first_chunk_ms": 1035,
    "cnn_redundancy_ms": 20,
    "sample_rate": 16000,
    # Sliding window config (packed into DuplexWindowConfig; any mode other than "off" enables the window)
    "sliding_window_mode": "off",
    "basic_window_high_tokens": 8000,
    "basic_window_low_tokens": 6000,
    "context_previous_max_tokens": 500,
    "context_max_units": 24,
}
@classmethod
def from_existing_model(
    cls,
    model: "MiniCPMO",
    device: Optional[str] = None,
    **kwargs,
) -> "MiniCPMODuplex":
    """Create MiniCPMODuplex from an existing MiniCPMO instance.

    The instance is allocated with ``cls.__new__`` (``__init__`` is not called)
    and keeps a reference to ``model`` instead of loading weights again.

    Args:
        model: Already-loaded model to wrap. Its ``processor`` attribute is
            (re)assigned, ``init_tts`` is called on it, and the stream timing
            constants are copied onto it.
        device: Device string to record on the instance. When ``None`` the
            device of the first model parameter is used, falling back to
            ``"cuda"`` if the model has no parameters.
        **kwargs: Overrides for any key of ``_default_duplex_params``.

    Returns:
        The configured duplex wrapper with its streaming state reset.
    """
    # Create instance without calling __init__
    instance = cls.__new__(cls)
    instance.name_or_path = getattr(model.config, "_name_or_path", "")

    # Get default params helper: explicit kwargs win over the class defaults
    def get_param(name):
        if name in kwargs:
            return kwargs[name]
        return cls._default_duplex_params.get(name)

    instance.generate_audio = get_param("generate_audio")
    instance.ls_mode = get_param("ls_mode")

    # Determine device
    if device is not None:
        instance.device = device
    else:
        try:
            instance.device = str(next(model.parameters()).device)
        except StopIteration:
            instance.device = "cuda"

    # Reuse the existing model - THIS IS THE KEY: no reloading!
    instance.model = model
    instance.processor = getattr(model, "processor", None)
    instance.tokenizer = getattr(instance.processor, "tokenizer", None) if instance.processor else None

    if instance.tokenizer is None:
        from transformers import AutoTokenizer

        instance.tokenizer = AutoTokenizer.from_pretrained(instance.name_or_path, trust_remote_code=True)

    if instance.processor is None:
        from .processing_minicpmo import MiniCPMOProcessor

        instance.processor = MiniCPMOProcessor.from_pretrained(instance.name_or_path, trust_remote_code=True)
        instance.processor.tokenizer = instance.tokenizer

    # Ensure model has processor reference (same as __init__)
    instance.model.processor = instance.processor

    # Initialize TTS (same as __init__)
    enable_float16 = get_param("enable_float16")
    n_timesteps = get_param("n_timesteps")
    instance.model.init_tts(enable_float16=enable_float16, n_timesteps=n_timesteps)

    instance.break_event = threading.Event()
    instance.session_stop_event = threading.Event()

    # LLM generation config
    instance.max_new_speak_tokens_per_chunk = get_param("max_new_speak_tokens_per_chunk")
    instance.text_repetition_penalty = get_param("text_repetition_penalty")
    instance.temperature = get_param("temperature")
    instance.top_k = get_param("top_k")
    instance.top_p = get_param("top_p")
    instance.text_repetition_window_size = get_param("text_repetition_window_size")
    instance.listen_prob_scale = get_param("listen_prob_scale")
    instance.force_listen_count = get_param("force_listen_count")

    # TTS generation config
    tts_temp_value = get_param("tts_temperature")
    instance.tts_temperature = torch.tensor([tts_temp_value], dtype=torch.float, device=instance.device)
    instance.tts_repetition_penalty = get_param("tts_repetition_penalty")

    # Stream config (mirrored onto the wrapped model)
    instance.CHUNK_MS = get_param("chunk_ms")
    instance.FIRST_CHUNK_MS = get_param("first_chunk_ms")
    instance.CNN_REDUNDANCY_MS = get_param("cnn_redundancy_ms")
    instance.SAMPLE_RATE = get_param("sample_rate")
    instance.model.CHUNK_MS = instance.CHUNK_MS
    instance.model.FIRST_CHUNK_MS = instance.FIRST_CHUNK_MS
    instance.model.CNN_REDUNDANCY_MS = instance.CNN_REDUNDANCY_MS
    instance.model.SAMPLE_RATE = instance.SAMPLE_RATE

    # Special tokens
    # NOTE(review): the next five literals were empty strings in the reviewed copy, so every
    # lookup asked the tokenizer for the id of "" (apparently angle-bracket tags were stripped).
    # They are restored here from the attribute names - confirm against the tokenizer vocabulary.
    instance.unit_token_id = instance.tokenizer.convert_tokens_to_ids("<unit>")
    instance.image_start_token_id = instance.tokenizer.convert_tokens_to_ids("<image>")
    instance.image_end_token_id = instance.tokenizer.convert_tokens_to_ids("</image>")
    instance.slice_start_token_id = instance.tokenizer.convert_tokens_to_ids("<slice>")
    instance.slice_end_token_id = instance.tokenizer.convert_tokens_to_ids("</slice>")
    instance.listen_token_id = instance.tokenizer.convert_tokens_to_ids("<|listen|>")
    instance.speak_token_id = instance.tokenizer.convert_tokens_to_ids("<|speak|>")
    instance.tts_bos_token_id = instance.tokenizer.convert_tokens_to_ids("<|tts_bos|>")
    instance.tts_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|tts_eos|>")
    instance.chunk_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|chunk_eos|>")
    instance.chunk_tts_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>")
    instance.turn_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|turn_eos|>")

    instance.chunk_terminator_token_ids = [
        instance.listen_token_id,
        instance.chunk_eos_token_id,
        instance.chunk_tts_eos_token_id,
    ]
    instance.turn_terminator_token_ids = [instance.turn_eos_token_id]
    instance.chunk_speak_token_ids = [instance.speak_token_id]

    instance.tts_pad_id = instance.tokenizer.convert_tokens_to_ids("<|tts_pad|>")
    bad_token_ids = getattr(instance.tokenizer, "bad_token_ids", [])
    instance.forbidden_token_ids = [instance.tts_pad_id] + list(bad_token_ids)

    from .utils import StreamDecoder

    instance.decoder = StreamDecoder(
        llm=instance.model.llm, tokenizer=instance.tokenizer, forbidden_token_ids=instance.forbidden_token_ids
    )

    # Sliding window config
    sliding_window_mode = get_param("sliding_window_mode")
    basic_window_high_tokens = get_param("basic_window_high_tokens")
    basic_window_low_tokens = get_param("basic_window_low_tokens")
    context_previous_max_tokens = get_param("context_previous_max_tokens")
    context_max_units = get_param("context_max_units")

    instance.decoder.set_window_config(
        DuplexWindowConfig(
            sliding_window_mode=sliding_window_mode,
            basic_window_high_tokens=basic_window_high_tokens,
            basic_window_low_tokens=basic_window_low_tokens,
            context_previous_max_tokens=context_previous_max_tokens,
            context_max_units=context_max_units,
        )
    )
    window_enabled = sliding_window_mode != "off"
    instance.decoder.set_window_enabled(window_enabled)

    # TTS sampling helpers are only built when audio output is requested
    instance.tts_logits_processors = None
    instance.tts_eos_token = None
    if instance.generate_audio:
        instance.tts_logits_processors = gen_logits(
            num_code=instance.model.tts.config.num_audio_tokens,
            repetition_penalty=instance.tts_repetition_penalty,
        )
        instance.tts_eos_token = torch.tensor(
            [instance.model.tts.config.num_audio_tokens - 1],
            dtype=torch.long,
            device=instance.device,
        )

    instance._reset_streaming_state()
    return instance
def set_break_event(self):
    """Raise the break flag (sets ``self.break_event``)."""
    event = self.break_event
    event.set()
def clear_break_event(self):
    """Lower the break flag (clears ``self.break_event``)."""
    event = self.break_event
    event.clear()
def set_session_stop(self):
    """Request a session stop: sets the session-stop flag, then the break flag."""
    for event in (self.session_stop_event, self.break_event):
        event.set()
def clear_session_stop(self):
    """Clear the session-stop flag; the break flag is left untouched."""
    event = self.session_stop_event
    event.clear()
def is_break_set(self) -> bool:
    """Report whether the break flag is currently set."""
    flag = self.break_event.is_set()
    return flag
def is_session_stop_set(self) -> bool:
    """Report whether the session-stop flag is currently set."""
    flag = self.session_stop_event.is_set()
    return flag
def _init_token2wav_cache(self, prompt_wav_path: str):
    """Build the token-to-waveform stream caches from a prompt wav and keep base copies.

    Args:
        prompt_wav_path: Passed to the audio tokenizer's ``set_stream_cache``.
    """
    audio_tokenizer = self.model.tts.audio_tokenizer
    audio_tokenizer.cache = None

    flow_cache, hift_cache = audio_tokenizer.set_stream_cache(prompt_wav_path)
    # Keep independent clones as the baseline that later resets copy from.
    self.flow_cache_base = torch_clone_recursive(flow_cache)
    self.hift_cache_base = torch_clone_recursive(hift_cache)

    lookahead = audio_tokenizer.flow.pre_lookahead_len
    self.pre_lookahead = int(lookahead)
    self.token2wav_initialized = True
def _reset_token2wav_for_new_turn(self):
    """Restore the token-to-waveform state from the cached baselines; no-op until initialised."""
    if not self.token2wav_initialized:
        return

    audio_tokenizer = self.model.tts.audio_tokenizer
    # Fresh clones so this turn cannot mutate the stored baselines.
    audio_tokenizer.stream_cache = torch_clone_recursive(self.flow_cache_base)
    audio_tokenizer.hift_cache_dict = torch_clone_recursive(self.hift_cache_base)
    self.token2wav_buffer = [4218] * 3  # silence token prefix
def _reset_streaming_state(self):
    """Put every per-session streaming field back to its initial value."""
    # Turn / generation bookkeeping
    self.audio_chunk_idx = 0
    self.speak_count = 0
    self.current_turn_ended = True
    self.res_ids, self.total_ids, self.total_hidden = [], [], []

    # TTS state
    self.tts_text_start_pos = 0
    self.tts_past_key_values = None
    self.tts_current_turn_start_time = None

    # token2wav state: marked uninitialised, with no buffered tokens and no cached baselines
    self.token2wav_initialized = False
    self.token2wav_buffer = []
    self.flow_cache_base = self.hift_cache_base = None

    # Audio prefill state
    self.audio_buffer = np.array([], dtype=np.float32)
    self.pending_logits: Optional[torch.Tensor] = None
    self.current_mode: Optional[str] = None

    # Force listen state
    self._streaming_generate_count = 0

    # Schema tracking: record the complete prefill + generate token sequence.
    # prefill_schema_tokens holds one list of prefill tokens per unit:
    #   [[unit0_prefill_tokens], [unit1_prefill_tokens], ...]
    self.prefill_schema_tokens = []
    self._current_unit_prefill_tokens = []
def prepare(
    self,
    prefix_system_prompt: Optional[str] = None,
    ref_audio: Optional[np.ndarray] = None,
    prompt_wav_path: Optional[str] = None,
    context_previous_marker: str = "\n\nprevious: ",
    **kwargs,
):
    """Reset the duplex session and prefill the system prompt (optionally with reference audio).

    Layout fed to the decoder: ``<|im_start|>system\\n{prompt}`` then, when
    reference audio is given, ``\\n<|audio_start|>`` + audio embeddings +
    ``<|audio_end|>``, and finally ``<|im_end|>``.

    Args:
        prefix_system_prompt: System prompt text. Falls back to
            ``"Streaming Omni Conversation."`` when empty or ``None``.
        ref_audio: Reference audio handed to ``self.processor.process_audio``.
            The audio markers are added whenever this is not ``None``, so the
            embeddings that get fed are always delimited.
        prompt_wav_path: When non-empty and ``self.generate_audio`` is true,
            used to initialise the token-to-waveform caches.
        context_previous_marker: Passed to the decoder when the sliding window
            mode is ``"context"``.
        **kwargs: Accepted and ignored by this method.

    Returns:
        The prompt text that was prefilled, with ``"[audio embedding]"``
        standing in for the reference-audio embeddings.
    """
    has_ref_audio = ref_audio is not None

    # After these lines both prompt parts are always non-empty strings.
    prefix_system_prompt = "<|im_start|>system\n" + (prefix_system_prompt or "Streaming Omni Conversation.")
    suffix_system_prompt = "<|im_end|>"
    if has_ref_audio:
        prefix_system_prompt += "\n<|audio_start|>"
        suffix_system_prompt = "<|audio_end|>" + suffix_system_prompt

    self.clear_break_event()
    self.clear_session_stop()
    self._reset_streaming_state()
    self.decoder.reset()
    self.model.init_streaming_processor()

    if prompt_wav_path and self.generate_audio:
        self._init_token2wav_cache(prompt_wav_path)
        self._reset_token2wav_for_new_turn()

    # Prefill system prompt prefix
    for token_id in self.tokenizer.encode(prefix_system_prompt, add_special_tokens=False):
        self.decoder.feed(self.decoder.embed_token(token_id))

    # Prefill reference audio
    if has_ref_audio:
        data = self.processor.process_audio([ref_audio])
        embeds_nested = self.model.get_audio_embedding(data, chunk_length=self.model.config.audio_chunk_length)
        # Flatten first: torch.cat raises on an empty list, so only concatenate when tensors exist.
        flat_embeds = [t for g in embeds_nested for t in g] if embeds_nested else []
        if flat_embeds:
            self.decoder.feed(torch.cat(flat_embeds, dim=0))

    # register system prompt protection length
    # (protect this part from being removed when sliding window is enabled)
    suffix_token_ids = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False)
    if self.decoder._window_config.sliding_window_mode == "context":
        # Context preserve mode:
        # initial layout: [prefix] [suffix] [units...]
        # after the first sliding window: [prefix] [context_previous_marker + content] [suffix] [units...]
        # register prefix length first, then feed suffix
        self._prefix_system_prompt = prefix_system_prompt
        self._suffix_system_prompt = suffix_system_prompt
        self._ref_audio = ref_audio

        # register (when cache only has prefix, no suffix, no previous)
        self.decoder.register_system_prompt_with_context(
            suffix_token_ids=suffix_token_ids,
            context_previous_marker=context_previous_marker,  # dynamically added after the first sliding window
        )

        # now feed suffix
        for token_id in suffix_token_ids:
            self.decoder.feed(self.decoder.embed_token(token_id))
    else:
        # non-context preserve mode: first feed suffix, then register total length
        for token_id in suffix_token_ids:
            self.decoder.feed(self.decoder.embed_token(token_id))
        self.decoder.register_system_prompt()

    if has_ref_audio:
        return prefix_system_prompt + "[audio embedding]" + suffix_system_prompt
    return prefix_system_prompt + suffix_system_prompt
@torch.no_grad()
def streaming_prefill(
    self,
    audio_waveform: Optional[np.ndarray] = None,
    frame_list: Optional[list] = None,
    text_list: Optional[list] = None,
    max_slice_nums: Union[int, List[int]] = 1,
    batch_vision_feed: bool = False,
) -> dict:
    """Prefill one streaming unit with image frames, audio and/or text.

    The original documentation describes this as being called once per second
    of input. The call opens a new unit in the decoder, feeds whatever
    modalities were supplied and stores the logits of the last fed position in
    ``self.pending_logits`` so that ``streaming_generate`` can continue from it.

    Args:
        audio_waveform: audio samples to append to ``self.audio_buffer``
        frame_list: image frames to encode and feed
        text_list: text items; a list is joined with "" before tokenization
        max_slice_nums: maximum number of slices for HD image encoding
            (default 1, no slicing). Either one int used for every frame or a
            list with exactly one entry per frame.
        batch_vision_feed: if True, concatenate all vision tokens/embeddings and
            feed them in a single decoder call; if False (default), feed each
            item individually.

    Process:
        0. determine mode from the inputs: OMNI (frames + audio), VISION,
           AUDIO or TEXT; return a failure result if nothing was supplied
        1. feed the unit token (``self.unit_token_id``)
        2. embed and feed images (if frames); in VISION mode the logits of the
           last fed item become ``self.pending_logits``
        3. embed and feed audio (if audio); its logits become
           ``self.pending_logits``
        4. feed text (if any); logits are only kept in TEXT mode

    Returns:
        dict with keys:
            - success: bool
            - reason: str (failure description, "" when none was given)
            - cost_vision_process: float (image processing time)
            - cost_vision_embed: float (vision embedding time)
            - cost_vision_feed: float (vision feed time)
            - cost_audio_process: float (audio processing time)
            - cost_audio_embed: float (audio embedding time)
            - cost_audio_feed: float (audio feed time)
            - cost_all: float (total time)

    Raises:
        ValueError: if ``max_slice_nums`` is a list whose length differs from
            ``frame_list``.
    """
    start_time = time.time()
    cost_vision_process = 0.0
    cost_vision_embed = 0.0
    cost_vision_feed = 0.0
    cost_audio_process = 0.0
    cost_audio_embed = 0.0
    cost_audio_feed = 0.0

    def _make_result(success, reasons=""):
        # Closure over the cost_* locals: the values are read when the result is
        # built, so every return path reports the timings measured so far.
        reason = reasons
        if isinstance(reasons, list):
            reason = "; ".join(reasons)
        return {
            "success": success,
            "reason": reason,
            "cost_vision_process": cost_vision_process,
            "cost_vision_embed": cost_vision_embed,
            "cost_vision_feed": cost_vision_feed,
            "cost_audio_process": cost_audio_process,
            "cost_audio_embed": cost_audio_embed,
            "cost_audio_feed": cost_audio_feed,
            "cost_all": time.time() - start_time,
        }

    # bail out before touching any state if the session was stopped or interrupted
    if self.is_session_stop_set() or self.is_break_set():
        return _make_result(False)
    has_frames = frame_list is not None and len(frame_list) > 0
    has_audio = audio_waveform is not None and len(audio_waveform) > 0
    has_text = text_list is not None and len(text_list) > 0
    if has_frames and has_audio:
        mode = "OMNI"
    elif has_frames:
        mode = "VISION"
    elif has_audio:
        mode = "AUDIO"
    elif has_text:
        mode = "TEXT"
    else:
        # nothing to prefill
        return _make_result(False)
    # logits from a previous unit must not leak into this one
    self.pending_logits = None
    # sliding window: record unit start position
    self.decoder.register_unit_start()
    # Schema tracking: start new unit, record prefill tokens
    self._current_unit_prefill_tokens = []
    # Step 1: Feed the unit token
    self.decoder.feed(self.decoder.embed_token(self.unit_token_id))
    self._current_unit_prefill_tokens.append(self.unit_token_id)
    # Step 2: process image
    if has_frames:
        t0 = time.time()
        # normalize max_slice_nums to a list matching frame_list length
        if isinstance(max_slice_nums, int):
            max_slice_nums_list = [max_slice_nums] * len(frame_list)
        else:
            max_slice_nums_list = list(max_slice_nums)
            if len(max_slice_nums_list) != len(frame_list):
                raise ValueError(
                    f"max_slice_nums list length ({len(max_slice_nums_list)}) "
                    f"must match frame_list length ({len(frame_list)})"
                )
        # check if all max_slice_nums are the same (can use batch processing)
        all_same = len(set(max_slice_nums_list)) == 1
        if all_same:
            # all images use the same max_slice_nums, use batch processing
            processed_frames = self.processor.process_image(frame_list, max_slice_nums=max_slice_nums_list[0])
            if self.device:
                processed_frames = processed_frames.to(self.device)
        else:
            # different max_slice_nums per image, process individually and merge
            all_pixel_values = []
            all_tgt_sizes = []
            for frame, max_slices in zip(frame_list, max_slice_nums_list):
                pf = self.processor.process_image([frame], max_slice_nums=max_slices)
                if self.device:
                    pf = pf.to(self.device)
                # pf["pixel_values"][0] is the list of slices for this image
                all_pixel_values.extend(pf["pixel_values"][0])
                # pf["tgt_sizes"][0] is the array of target sizes for this image's slices
                if hasattr(pf["tgt_sizes"][0], "tolist"):
                    all_tgt_sizes.extend(pf["tgt_sizes"][0].tolist())
                else:
                    all_tgt_sizes.extend(list(pf["tgt_sizes"][0]))
            # reconstruct processed_frames with merged data
            # NOTE(review): this is a plain dict built on the host; unlike the batch
            # branch, tgt_sizes is not moved to self.device here - confirm downstream
            processed_frames = {
                "pixel_values": [all_pixel_values],
                "tgt_sizes": [torch.tensor(all_tgt_sizes) if all_tgt_sizes else []],
            }
        cost_vision_process = time.time() - t0
        t0 = time.time()
        # get vision embeddings for all images (each may have multiple slices)
        # vision_hidden_states is a list, one entry per input image
        # each entry contains embeddings for [source_image, slice_1, slice_2, ...]
        vision_hidden_states = self.model.get_vision_embedding(processed_frames)
        cost_vision_embed = time.time() - t0
        if vision_hidden_states is not None and len(vision_hidden_states) > 0:
            t0 = time.time()
            # vision_hidden_states[0] contains ALL slices from ALL images (flattened)
            # shape: [total_slices, 64, D] where total_slices = sum of slices across all images
            # we need to know how many slices each image has to correctly group them
            # calculate slice counts for each image using get_sliced_grid (lightweight, no actual slicing)
            slice_counts = []  # e.g., [5, 9] means img1 has 5 slices (1 source + 4 HD), img2 has 9 slices
            for frame_idx, frame in enumerate(frame_list):
                max_slices = max_slice_nums_list[frame_idx]
                if hasattr(frame, "size"):
                    # get_sliced_grid returns [M, N] grid or None if no slicing needed
                    # total images = 1 (source) + M * N (HD slices)
                    # ("nerver_split" is the keyword name as spelled at this call site)
                    grid = self.processor.image_processor.get_sliced_grid(
                        frame.size, max_slices, nerver_split=False
                    )
                    if grid is not None:
                        slice_counts.append(1 + grid[0] * grid[1])  # 1 source + M*N slices
                    else:
                        slice_counts.append(1)  # no slicing, only source image
                else:
                    slice_counts.append(1)  # default: single image, no slicing
            # get the flattened embeddings tensor
            # vision_hidden_states is a list with one element (the batch)
            # vision_hidden_states[0] shape: [total_slices, 64, D]
            all_embeds = vision_hidden_states[0]
            # collect all feed operations first, then execute
            # this allows us to identify the last token for VISION mode logits
            feed_operations = []  # List of (embed, is_last_for_vision_mode, token_id_or_none)
            embed_idx = 0  # current index in all_embeds
            for img_idx, num_slices in enumerate(slice_counts):
                if num_slices == 0:
                    continue
                # the first embedding is always the source image (downsampled overview)
                # Feed the image-start token
                feed_operations.append(
                    (self.decoder.embed_token(self.image_start_token_id), False, self.image_start_token_id)
                )
                # Feed source image embedding (shape: [64, D]) - use None to indicate embedding
                feed_operations.append((all_embeds[embed_idx], False, None))
                # Feed the image-end token
                feed_operations.append(
                    (self.decoder.embed_token(self.image_end_token_id), False, self.image_end_token_id)
                )
                embed_idx += 1
                # remaining embeddings are HD slices (if num_slices > 1)
                if num_slices > 1:
                    for slice_i in range(1, num_slices):
                        # Feed the slice-start token
                        feed_operations.append(
                            (self.decoder.embed_token(self.slice_start_token_id), False, self.slice_start_token_id)
                        )
                        # Feed slice embedding (shape: [64, D])
                        feed_operations.append((all_embeds[embed_idx], False, None))
                        # Feed the slice-end token
                        feed_operations.append(
                            (self.decoder.embed_token(self.slice_end_token_id), False, self.slice_end_token_id)
                        )
                        embed_idx += 1
            # mark the last operation for VISION mode logits
            if feed_operations:
                feed_operations[-1] = (feed_operations[-1][0], True, feed_operations[-1][2])
            # execute feed operations
            if batch_vision_feed and feed_operations:
                # batch mode: concatenate all embeddings and feed at once
                # this reduces LLM forward passes from N to 1
                #
                # NOTE: batch mode may have slight numerical differences compared to for-loop mode
                # due to floating-point precision in attention computation. This is expected behavior
                # for causal attention with incremental vs batch computation.
                all_embeds_list = []
                for embed, is_last, token_id in feed_operations:
                    # ensure all embeddings have shape [L, H]
                    if embed.dim() == 1:
                        embed = embed.unsqueeze(0)
                    all_embeds_list.append(embed)
                # concatenate all embeddings
                # torch.cat requires consistent dtype; embeddings should already be same dtype
                all_embeds_to_feed = torch.cat(all_embeds_list, dim=0)  # [total_L, H]
                if mode == "VISION":
                    # vision mode needs logits from the last token
                    self.pending_logits, _ = self.decoder.feed(all_embeds_to_feed, return_logits=True)
                else:
                    # omni mode: just feed, wait for audio to get logits
                    self.decoder.feed(all_embeds_to_feed)
                # schema tracking: record all token IDs and embedding markers
                for embed, is_last, token_id in feed_operations:
                    if token_id is not None:
                        self._current_unit_prefill_tokens.append(token_id)
                    else:
                        # the marker stores the first dimension of the embedding
                        embed_dim = embed.shape[0] if len(embed.shape) > 1 else 1
                        self._current_unit_prefill_tokens.append(("img", embed_dim))
            else:
                # per-item mode (also taken when there is nothing to feed)
                for embed, is_last, token_id in feed_operations:
                    if mode == "VISION" and is_last:
                        # get logits from the last token
                        self.pending_logits, _ = self.decoder.feed(embed, return_logits=True)
                    else:
                        self.decoder.feed(embed)
                    # schema tracking: record token ID or embedding marker
                    if token_id is not None:
                        self._current_unit_prefill_tokens.append(token_id)
                    else:
                        # use tuple to mark image embedding: ("img", dim)
                        embed_dim = embed.shape[0] if len(embed.shape) > 1 else 1
                        self._current_unit_prefill_tokens.append(("img", embed_dim))
            # for omni mode, no pending logits needed here (wait for audio)
            cost_vision_feed = time.time() - t0
    # Step 3: process audio (if any)
    if has_audio:
        # accumulate audio to buffer
        self.audio_buffer = np.concatenate([self.audio_buffer, audio_waveform])
        # calculate required audio length
        if self.audio_chunk_idx == 0:
            required_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000)
            if len(self.audio_buffer) < required_samples:
                # first chunk too short: left-pad with zeros so the real audio stays at the end
                padding_samples = required_samples - len(self.audio_buffer)
                padding = np.zeros(padding_samples, dtype=np.float32)
                self.audio_buffer = np.concatenate([padding, self.audio_buffer])
        else:
            # NOTE(review): this value is not read again below; the length check uses need_samples
            required_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000)
        need_samples = self.processor.get_streaming_chunk_size()
        if len(self.audio_buffer) < need_samples:
            # NOTE(review): at this point the unit token (and any images) were already fed
            # and the unit is not recorded in prefill_schema_tokens - confirm callers cope
            return _make_result(
                False, f"audio not enough: need {need_samples} samples, only {len(self.audio_buffer)}"
            )
        audio_chunk = self.audio_buffer[:need_samples]
        t0 = time.time()
        batch_feature = self.processor.process_audio_streaming(
            audio_chunk,
            reset=False,
            return_batch_feature=True,
        )
        if batch_feature is None or batch_feature.audio_features.shape[-1] == 0:
            return _make_result(False, "streaming audio processing returned empty")
        # metadata
        batch_feature.chunk_idx = self.audio_chunk_idx
        batch_feature.use_extra_context = True
        # the first chunk has no preceding context, so no prefix frames are requested for it
        batch_feature.prefix_extra_frames = 0 if self.audio_chunk_idx == 0 else 2
        batch_feature.suffix_extra_frames = 2
        batch_feature = batch_feature.to(self.device)
        cost_audio_process = time.time() - t0
        t0 = time.time()
        embeds_nested = self.model.get_audio_embedding_streaming(
            batch_feature,
            use_extra_context=batch_feature.use_extra_context,
            prefix_extra_frames=batch_feature.prefix_extra_frames,
            suffix_extra_frames=batch_feature.suffix_extra_frames,
        )
        # flatten the nested (group -> tensor) result into one tensor along dim 0
        audio_embeds = torch.cat([t for g in embeds_nested for t in g], dim=0)
        cost_audio_embed = time.time() - t0
        t0 = time.time()
        # audio is fed after the images, so its logits are the ones kept in AUDIO/OMNI mode
        self.pending_logits, _ = self.decoder.feed(audio_embeds, return_logits=True)
        cost_audio_feed = time.time() - t0
        # schema tracking: use tuple to mark audio embedding: ("audio", dim)
        embed_dim = audio_embeds.shape[0] if len(audio_embeds.shape) > 1 else 1
        self._current_unit_prefill_tokens.append(("audio", embed_dim))
        # drop the consumed samples; the first chunk takes its length from the mel processor config
        if self.audio_chunk_idx == 0:
            cfg = self.processor._streaming_mel_processor.get_config()
            consumed_ms = int(cfg.get("effective_first_chunk_ms", self.FIRST_CHUNK_MS))
            consumed_samples = int(consumed_ms * self.SAMPLE_RATE / 1000)
        else:
            consumed_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000)
        self.audio_buffer = self.audio_buffer[consumed_samples:]
        self.audio_chunk_idx += 1
    # Step 4: process text
    if has_text:
        # concatenate all text items
        text_content = "".join(text_list) if isinstance(text_list, list) else str(text_list)
        # tokenize text
        text_token_ids = self.tokenizer.encode(text_content, add_special_tokens=False)
        if len(text_token_ids) > 0:
            # get token embeddings
            text_token_ids_tensor = torch.tensor(text_token_ids, dtype=torch.long, device=self.device)
            text_embeds = self.decoder.embed_token(text_token_ids_tensor)
            # feed to decoder
            if mode == "TEXT":
                # text-only mode: get logits from the last token
                self.pending_logits, _ = self.decoder.feed(text_embeds, return_logits=True)
            else:
                # mixed mode: just feed, let other modality get logits
                # NOTE(review): text is fed after the position whose logits were stored
                # above, so pending_logits does not account for it - confirm intended
                self.decoder.feed(text_embeds)
            # schema tracking: record text token IDs
            for token_id in text_token_ids:
                self._current_unit_prefill_tokens.append(token_id)
    self.current_mode = mode
    if mode == "VISION":
        # no audio branch ran, so the chunk counter is advanced here instead
        self.audio_chunk_idx += 1
    # schema tracking: save current unit's prefill tokens
    self.prefill_schema_tokens.append(self._current_unit_prefill_tokens)
    return _make_result(True)
@torch.no_grad()
def streaming_generate(
    self,
    prompt_wav_path=None,
    max_new_speak_tokens_per_chunk=20,
    decode_mode: str = "sampling",
    temperature=0.7,
    top_k=100,
    top_p=0.8,
    listen_prob_scale=1.0,
    listen_top_k=None,
    text_repetition_penalty=1.05,
    text_repetition_window_size=512,
) -> dict:
    """Generate the response for the unit prepared by ``streaming_prefill``.

    Consumes ``self.pending_logits``, decodes up to
    ``max_new_speak_tokens_per_chunk - 1`` tokens, closes the unit in the
    decoder, applies the configured sliding window and, when the model speaks
    and ``self.generate_audio`` is set, runs TTS and token-to-waveform
    conversion for the generated tokens.

    Args:
        prompt_wav_path: forwarded to ``_generate_waveform_from_tokens``
        max_new_speak_tokens_per_chunk: loop bound for decoding; the last
            iteration never decodes and only closes the chunk
        decode_mode, temperature, top_k, top_p, listen_top_k, listen_prob_scale,
        text_repetition_penalty, text_repetition_window_size: forwarded
            unchanged to ``self.decoder.decode``

    Returns:
        dict with keys ``is_listen``, ``text``, ``audio_waveform``,
        ``end_of_turn``, ``current_time``, ``cost_llm``, ``cost_tts_prep``,
        ``cost_tts``, ``cost_token2wav``, ``cost_all``, ``n_tokens`` and
        ``n_tts_tokens``. ``audio_waveform`` is a silence waveform on the listen
        paths, ``None`` when audio generation is disabled, and otherwise
        whatever ``_generate_waveform_from_tokens`` returns.
    """
    start_time = time.time()
    # session stopped / interrupted: report a listen result that ends the turn
    if self.is_session_stop_set() or self.is_break_set():
        return {
            "is_listen": True,
            "text": "",
            "audio_waveform": self._generate_silence_waveform(),
            "end_of_turn": True,
            "current_time": self.audio_chunk_idx,
            "cost_llm": 0.0,
            "cost_tts_prep": 0.0,
            "cost_tts": 0.0,
            "cost_token2wav": 0.0,
            "cost_all": time.time() - start_time,
            "n_tokens": 0,
            "n_tts_tokens": 0,
        }
    # check if there are pending logits to process
    if not hasattr(self, "pending_logits") or self.pending_logits is None:
        # nothing was prefilled: report a listen result without ending the turn
        return {
            "is_listen": True,
            "text": "",
            "audio_waveform": self._generate_silence_waveform(),
            "end_of_turn": False,
            "current_time": self.audio_chunk_idx,
            "cost_llm": 0.0,
            "cost_tts_prep": 0.0,
            "cost_tts": 0.0,
            "cost_token2wav": 0.0,
            "cost_all": time.time() - start_time,
            "n_tokens": 0,
            "n_tts_tokens": 0,
        }
    # use pending logits generated in streaming_prefill
    logits = self.pending_logits
    self.pending_logits = None
    # Force listen: check if we should force listen for first N calls
    force_listen = self._streaming_generate_count < self.force_listen_count
    self._streaming_generate_count += 1
    total_hidden_in_unit = []
    total_ids_in_unit = []
    current_time = self.audio_chunk_idx
    is_listen = False
    end_of_turn = False
    llm_start_time = time.time()
    for j in range(max_new_speak_tokens_per_chunk):
        if j == max_new_speak_tokens_per_chunk - 1:
            # token budget exhausted: close the chunk (explicitly in "explicit" mode) and stop
            if self.ls_mode == "explicit":
                self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id))
                self.total_ids.append(self.chunk_eos_token_id)
            break
        if force_listen:
            last_id = torch.tensor([self.listen_token_id], dtype=torch.long, device=self.device)
        else:
            last_id = self.decoder.decode(
                logits=logits,
                mode=decode_mode,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                listen_top_k=listen_top_k,
                listen_prob_scale=listen_prob_scale,
                text_repetition_penalty=text_repetition_penalty,
                text_repetition_window_size=text_repetition_window_size,
            )
            # if current turn not ended, not allowed to listen (only check when not force_listen)
            if last_id.item() == self.listen_token_id and (not self.current_turn_ended):
                last_id = torch.tensor([self.tts_bos_token_id], dtype=torch.long, device=self.device)
        self.total_ids.append(last_id.item())
        is_listen = last_id.item() == self.listen_token_id
        # termination condition detection
        if last_id.item() in self.chunk_terminator_token_ids:
            if self.ls_mode == "explicit":
                logits, _ = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)
            break
        else:
            # normal speak
            self.current_turn_ended = False
            if last_id.item() in self.chunk_speak_token_ids:
                pass
            else:
                # only tokens outside chunk_speak_token_ids count as response text
                self.res_ids.append(last_id.item())
                self.speak_count += 1
            logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True)
            # a single fed token is expected to yield a [1, 1, H] hidden state
            assert len(hidden.shape) == 3
            assert hidden.shape[0] == 1
            assert hidden.shape[1] == 1
            end_of_turn = last_id.item() in self.turn_terminator_token_ids
            if end_of_turn:
                self.current_turn_ended = True
            if j != 0:
                # the token decoded at j == 0 is not recorded in the per-unit lists
                total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn])
                total_ids_in_unit.append(last_id.item())
    # Prefill the unit-end token
    # NOTE(review): the token string below is empty in this file; the variable name
    # suggests a unit-end marker - confirm the literal was not lost
    unit_end_id = self.tokenizer.convert_tokens_to_ids("")
    self.decoder.feed(self.decoder.embed_token(unit_end_id))
    self.total_ids.append(unit_end_id)
    # calculate generated text (for sliding window context preserve, filter out special tokens)
    generated_text = self.tokenizer.decode(total_ids_in_unit, skip_special_tokens=True) if total_ids_in_unit else ""
    # sliding window: register unit end, and check if sliding window is needed
    input_type = self.current_mode.lower() if self.current_mode else "audio"
    self.decoder.register_unit_end(
        input_type=input_type,
        generated_tokens=total_ids_in_unit,
        is_listen=is_listen,
        generated_text=generated_text,
    )
    # select sliding window method based on sliding window mode
    if self.decoder._window_config.sliding_window_mode == "context":
        self.decoder.enforce_window_with_context()
    elif self.decoder._window_config.sliding_window_mode == "basic":
        self.decoder.enforce_window()
    llm_end_time = time.time()
    if is_listen:
        # keep total_hidden aligned with units: a listen unit contributes an empty entry
        self.total_hidden.append([])
        return {
            "is_listen": True,
            "text": "",
            "audio_waveform": self._generate_silence_waveform(),
            "end_of_turn": False,
            "current_time": current_time,
            "cost_llm": llm_end_time - llm_start_time,
            "cost_tts_prep": 0.0,
            "cost_tts": 0.0,
            "cost_token2wav": 0.0,
            "cost_all": time.time() - start_time,
            "n_tokens": len(total_ids_in_unit),
            "n_tts_tokens": 0,
        }
    self.total_hidden.append(total_hidden_in_unit)
    text = generated_text  # reuse already calculated text
    if not self.generate_audio:
        # text-only output: skip TTS and waveform synthesis
        return {
            "is_listen": False,
            "text": text,
            "audio_waveform": None,
            "end_of_turn": end_of_turn,
            "current_time": current_time,
            "cost_llm": llm_end_time - llm_start_time,
            "cost_tts_prep": 0.0,
            "cost_tts": 0.0,
            "cost_token2wav": 0.0,
            "cost_all": time.time() - start_time,
            "n_tokens": len(total_ids_in_unit),
            "n_tts_tokens": 0,
        }
    # TTS generate
    # (cost_tts below is measured from here, so it includes the preparation time)
    tts_start_time = time.time()
    tts_prep_start_time = time.time()
    tts_condition = self._convert_results_to_tts_input(total_hidden_in_unit)
    tts_prep_end_time = time.time()
    # by default exactly 26 audio tokens are requested per chunk (min == max)
    max_token_per_chunk = 25 + 1
    min_token_per_chunk = 25 + 1
    if end_of_turn:
        # the final chunk of a turn may be shorter
        min_token_per_chunk = 0
    force_flush = False
    if self.tts_text_start_pos == 0:  # this is the start of the turn
        min_token_per_chunk = 0  # allow decoding <1s audio
        force_flush = True
    if self.tts_current_turn_start_time is None:
        self.tts_current_turn_start_time = current_time
    new_tokens, old_kv = self.model.tts.generate_chunk(
        inputs_embeds=tts_condition,
        temperature=self.tts_temperature,
        repetition_penalty=self.tts_repetition_penalty,
        eos_token=self.tts_eos_token,
        force_no_stop=False,
        max_new_token=max_token_per_chunk,
        min_new_tokens=min_token_per_chunk,
        past_key_values=self.tts_past_key_values,
        logits_processors=self.tts_logits_processors,
        text_start_pos=self.tts_text_start_pos,
    )
    tts_end_time = time.time()
    # update TTS state (note: token2wav reset must be after audio generation, otherwise tokens in buffer will be lost)
    if end_of_turn:
        self.tts_text_start_pos = 0
        self.tts_past_key_values = None
        self.tts_current_turn_start_time = None
    else:
        self.tts_past_key_values = old_kv
        # advance by the conditioning length plus the number of generated audio tokens
        self.tts_text_start_pos += tts_condition.shape[1] + new_tokens.shape[1]
    # token2wav generation (must be before reset, otherwise tokens in the last but second chunk will be lost)
    token2wav_start_time = time.time()
    audio_waveform = self._generate_waveform_from_tokens(
        new_tokens, prompt_wav_path, end_of_turn, force_flush=force_flush
    )
    token2wav_end_time = time.time()
    # reset token2wav state after audio generation, ensure all tokens in buffer are processed
    if end_of_turn:
        self._reset_token2wav_for_new_turn()
    end_time = time.time()
    return {
        "is_listen": False,
        "text": text,
        "audio_waveform": audio_waveform,
        "end_of_turn": end_of_turn,
        "current_time": current_time,
        "cost_llm": llm_end_time - llm_start_time,
        "cost_tts_prep": tts_prep_end_time - tts_prep_start_time,
        "cost_tts": tts_end_time - tts_start_time,
        "cost_token2wav": token2wav_end_time - token2wav_start_time,
        "cost_all": end_time - start_time,
        "n_tokens": len(total_ids_in_unit),
        "n_tts_tokens": new_tokens.numel(),
    }
def get_session_schema(self, include_embeddings: bool = True) -> str:
    """Render the whole session (prefill + generated tokens) as one string.

    Args:
        include_embeddings: when True, embedding entries recorded during
            prefill are rendered as placeholders such as ``[img_embed_64]`` or
            ``[audio_embed_50]``; when False they are omitted.

    Returns:
        The per-unit schemas concatenated in order. Each unit is its prefill
        part followed by its generated tokens. Returns "" when the session has
        no ``prefill_schema_tokens`` or ``total_ids`` attribute.
    """
    if not (hasattr(self, "prefill_schema_tokens") and hasattr(self, "total_ids")):
        return ""

    # id that closes a unit in the generated stream
    boundary_id = self.tokenizer.convert_tokens_to_ids("")

    # cut the generated ids into units; a trailing run without a boundary is dropped
    generated = []
    pending = []
    for token in self.total_ids:
        pending.append(token)
        if token == boundary_id:
            generated.append(pending)
            pending = []

    prefilled = self.prefill_schema_tokens
    rendered = []
    for idx in range(max(len(prefilled), len(generated))):
        pieces = []
        if idx < len(prefilled):
            for entry in prefilled[idx]:
                if isinstance(entry, tuple):
                    # ("img", dim) / ("audio", dim) embedding marker
                    kind, size = entry
                    if include_embeddings:
                        pieces.append(f"[{kind}_embed_{size}]")
                    continue
                # plain token id
                pieces.append(self.tokenizer.decode([entry], skip_special_tokens=False))
        if idx < len(generated):
            pieces.append(self.tokenizer.decode(generated[idx], skip_special_tokens=False))
        rendered.append("".join(pieces))
    return "".join(rendered)
def get_unit_schemas(self, include_embeddings: bool = True) -> list:
    """Return one schema string per unit of the current session.

    Args:
        include_embeddings: when True, embedding entries recorded during
            prefill are rendered as ``[<type>_embed_<dim>]`` placeholders;
            when False they are left out.

    Returns:
        A list with one string per unit (prefill part followed by the decoded
        generated tokens). Returns [] when the session has no
        ``prefill_schema_tokens`` or ``total_ids`` attribute.
    """
    missing_state = not hasattr(self, "prefill_schema_tokens") or not hasattr(self, "total_ids")
    if missing_state:
        return []

    decode = self.tokenizer.decode
    closing_id = self.tokenizer.convert_tokens_to_ids("")

    # group generated ids into units, each ending with the closing id;
    # ids after the last closing id do not form a unit
    units = []
    start = 0
    for position, token in enumerate(self.total_ids):
        if token == closing_id:
            units.append(self.total_ids[start : position + 1])
            start = position + 1

    prefill_units = self.prefill_schema_tokens
    total = max(len(prefill_units), len(units))
    schemas = []
    for unit_no in range(total):
        text = ""
        if unit_no < len(prefill_units):
            for element in prefill_units[unit_no]:
                if not isinstance(element, tuple):
                    # ordinary token id
                    text += decode([element], skip_special_tokens=False)
                    continue
                # embedding marker: ("img", dim) or ("audio", dim)
                tag, width = element
                if include_embeddings:
                    text += f"[{tag}_embed_{width}]"
        if unit_no < len(units):
            text += decode(units[unit_no], skip_special_tokens=False)
        schemas.append(text)
    return schemas
def _convert_results_to_tts_input(self, results):
    """Convert LLM tokens and hidden states into the TTS conditioning tensor.

    Args:
        results: list of ``[token_id, hidden, end_of_turn]`` entries as
            collected by ``streaming_generate``; ``hidden`` is squeezed on
            dim 0 and the results are concatenated along dim 0.

    Returns:
        Tensor with a leading batch dimension of 1. For N entries it holds N
        rows of ``emb_text(token) + normalize(projector_semantic(hidden))``
        followed by one row with the audio-BOS embedding; for an empty
        ``results`` it holds the audio-BOS embedding only.
    """
    tts = self.model.tts
    # the audio-BOS embedding terminates the conditioning in both cases
    audio_bos = tts.emb_text(
        torch.tensor(
            [tts.audio_bos_token_id],
            device=tts.emb_text.weight.device,
            dtype=torch.long,
        )
    )
    if len(results) == 0:
        return audio_bos.unsqueeze(0)

    llm_tokens = [entry[0] for entry in results]
    llm_hidden = [entry[1].squeeze(0) for entry in results]

    # build the ids directly as int64: going through torch.Tensor(...) would create a
    # float32 tensor first, which cannot represent every integer above 2**24
    llm_tokens_tensor = torch.tensor(llm_tokens, dtype=torch.long, device=self.device)
    llm_embeds = tts.emb_text(llm_tokens_tensor)

    llm_hidden_tensor = torch.cat(llm_hidden, dim=0)
    llm_hidden_tensor = tts.projector_semantic(llm_hidden_tensor)
    # L2-normalize the projected hidden states before adding them to the text embeddings
    llm_hidden_tensor = torch.nn.functional.normalize(llm_hidden_tensor, p=2, dim=-1)

    tts_embeds = llm_embeds + llm_hidden_tensor
    tts_embeds = torch.cat([tts_embeds, audio_bos], dim=0)
    return tts_embeds.unsqueeze(0)
def _generate_waveform_from_tokens(
    self,
    new_tokens: torch.Tensor,
    prompt_wav_path: Optional[str],
    is_last_chunk: bool = False,
    force_flush: bool = False,
) -> Optional[np.ndarray]:
    """Append new audio tokens to the token2wav buffer and synthesize what is ready.

    Args:
        new_tokens: audio token ids produced by the TTS model; flattened before
            being appended to ``self.token2wav_buffer``
        prompt_wav_path: passed to the audio tokenizer as ``prompt_wav``
        is_last_chunk: if True, everything left in the buffer is synthesized
            with ``last_chunk=True`` and the buffer is emptied
        force_flush: if True, synthesize as soon as at least
            ``self.pre_lookahead + 5`` tokens are buffered instead of waiting
            for a full chunk

    Returns:
        float32 waveform, or None when token2wav is not initialized or no audio
        was produced by this call.
    """
    if not self.token2wav_initialized:
        logger.warning("token2wav_initialized is uninitialized")
        return None
    CHUNK_SIZE = 25
    token_ids = torch.reshape(new_tokens, (-1,)).tolist()
    self.token2wav_buffer += token_ids
    has_chunk_eos = any(tid in self.chunk_terminator_token_ids for tid in token_ids)
    pcm_bytes_list = []
    if has_chunk_eos or force_flush:
        # eager mode: emit partial chunks, keeping only a small lookahead in the buffer
        while len(self.token2wav_buffer) >= self.pre_lookahead + 5:  # at least keep some lookahead
            chunk_to_process = min(CHUNK_SIZE + self.pre_lookahead, len(self.token2wav_buffer))
            pcm_bytes = self.model.tts.audio_tokenizer.stream(
                self.token2wav_buffer[:chunk_to_process],
                prompt_wav=prompt_wav_path,
            )
            pcm_bytes_list.append(pcm_bytes)
            # the lookahead tokens stay in the buffer for the next call
            self.token2wav_buffer = self.token2wav_buffer[min(CHUNK_SIZE, chunk_to_process - self.pre_lookahead) :]
    else:
        # regular mode: only full chunks (CHUNK_SIZE tokens plus lookahead) are synthesized
        while len(self.token2wav_buffer) >= CHUNK_SIZE + self.pre_lookahead:
            pcm_bytes = self.model.tts.audio_tokenizer.stream(
                self.token2wav_buffer[: CHUNK_SIZE + self.pre_lookahead],
                prompt_wav=prompt_wav_path,
            )
            pcm_bytes_list.append(pcm_bytes)
            self.token2wav_buffer = self.token2wav_buffer[CHUNK_SIZE:]
    # if is the last chunk, flush remaining tokens
    if is_last_chunk and len(self.token2wav_buffer) > 0:
        pcm_bytes = self.model.tts.audio_tokenizer.stream(
            self.token2wav_buffer,
            prompt_wav=prompt_wav_path,
            last_chunk=True,
        )
        pcm_bytes_list.append(pcm_bytes)
        self.token2wav_buffer = []
    if not pcm_bytes_list:
        return None
    # merge PCM and convert to numpy array (24kHz, int16 -> float32)
    all_pcm = b"".join(pcm_bytes_list)
    if len(all_pcm) == 0:
        return None
    # NOTE(review): this statement and the signature of _generate_silence_waveform were
    # fused into one broken line in this file; little-endian int16 scaled by 1/32768 is
    # reconstructed from the "int16 -> float32" comment above - confirm against upstream
    pcm_np = np.frombuffer(all_pcm, dtype="<i2").astype(np.float32) / 32768.0
    return pcm_np


def _generate_silence_waveform(self, duration_sec: float = 1.0) -> np.ndarray:
    """generate silence waveform (24kHz)

    Args:
        duration_sec: length of the silence in seconds.
            NOTE(review): the 1.0 default is reconstructed (callers in this file
            pass no argument) - confirm against upstream.

    Returns:
        float32 array of zeros with ``int(duration_sec * 24000)`` samples.
    """
    sample_rate = 24000
    num_samples = int(duration_sec * sample_rate)
    return np.zeros(num_samples, dtype=np.float32)
def get_generated_text(self) -> str:
    """Decode the response token ids collected in ``self.res_ids`` into text."""
    response_ids = self.res_ids
    decoded = self.tokenizer.decode(response_ids)
    return decoded
def get_current_time(self) -> int:
    """Return the current value of the streaming chunk counter ``self.audio_chunk_idx``."""
    chunk_index = self.audio_chunk_idx
    return chunk_index
def as_simplex(self, reset_session: bool = True, reset_token2wav_cache: bool = False) -> "MiniCPMO":
    """Hand back the wrapped model for simplex use, without reloading it.

    Args:
        reset_session: when True (default), ``self.model.reset_session`` is
            called before returning, clearing the streaming session state.
            Recommended when switching from duplex to simplex mode.
        reset_token2wav_cache: forwarded to ``reset_session``; only used when
            ``reset_session`` is True.

    Returns:
        ``self.model``, the underlying model instance.
    """
    wrapped = self.model
    if not reset_session:
        return wrapped
    wrapped.reset_session(reset_token2wav_cache=reset_token2wav_cache)
    return self.model
def get_2d_sincos_pos_embed(embed_dim, image_size):
    """Build a 2D sin-cos positional embedding table.

    Args:
        embed_dim: embedding size per position
        image_size: a single int (square grid) or (image_height, image_width)

    Returns:
        pos_embed: array of shape [image_height, image_width, embed_dim]
    """
    if isinstance(image_size, int):
        height = width = image_size
    else:
        height, width = image_size[0], image_size[1]

    rows = np.arange(height, dtype=np.float32)
    cols = np.arange(width, dtype=np.float32)
    # the width axis is passed first, so mesh[0] varies along width and mesh[1] along height
    mesh = np.stack(np.meshgrid(cols, rows), axis=0)
    return get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-channel coordinate grid into a sin-cos embedding.

    Each of ``grid[0]`` and ``grid[1]`` is encoded with ``embed_dim // 2``
    dimensions and the two halves are concatenated on the last axis, giving an
    array of shape (H, W, embed_dim). ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # one half of the channels per grid plane, in grid order
    halves = [get_1d_sincos_pos_embed_from_grid_new(half_dim, grid[axis]) for axis in (0, 1)]  # 2 x (H, W, D/2)
    return np.concatenate(halves, axis=-1)  # (H, W, D)
def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos):
    """Sin-cos encode a 2D array of positions.

    Args:
        embed_dim: output dimension for each position (must be even)
        pos: positions to be encoded, shape (H, W)

    Returns:
        Array of shape (H, W, embed_dim): the sine terms fill the first half of
        the last axis and the cosine terms the second half.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # frequencies 1 / 10000^(i / half_dim) for i in [0, half_dim)
    exponents = np.arange(half_dim, dtype=np.float32)
    exponents /= embed_dim / 2.0
    freqs = 1.0 / 10000**exponents  # (D/2,)
    angles = np.einsum("hw,d->hwd", pos, freqs)  # (H, W, D/2), outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (H, W, D)
class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross attention layer, using
    learnable queries and a 2D sin-cos positional embedding.

    Outputs:
        A tensor with the shape of (batch_size, num_queries, embed_dim)
    """

    def __init__(
        self,
        num_queries,
        embed_dim,
        num_heads,
        kv_dim=None,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        adaptive=False,
        max_size=(70, 70),
    ):
        """
        Args:
            num_queries: number of learnable query vectors (output length)
            embed_dim: width of queries, attention and output
            num_heads: number of attention heads
            kv_dim: width of the input features; when it is given and differs
                from ``embed_dim`` a bias-free linear projection is applied first
            norm_layer: factory for the three normalization layers
            adaptive: stored on the module; not read anywhere in this class
            max_size: (height, width) of the initially cached positional table
        """
        super().__init__()
        self.num_queries = num_queries
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.adaptive = adaptive
        self.max_size = max_size

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.ln_post = norm_layer(embed_dim)
        self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim))

        self._set_2d_pos_cache(self.max_size)

    def _set_2d_pos_cache(self, max_size, device="cpu"):
        """(Re)build the non-persistent ``pos_embed`` buffer for a grid of ``max_size``."""
        if is_deepspeed_zero3_enabled():
            device = "cuda"
        pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device)
        self.register_buffer("pos_embed", pos_embed, persistent=False)

    def _adjust_pos_cache(self, tgt_sizes, device):
        """Grow the cached positional table if any target size exceeds it."""
        # Convert to plain Python ints: keeping 0-dim tensors here would store tensors in
        # self.max_size and hand them to np.arange, which fails for CUDA tensors.
        max_h = int(tgt_sizes[:, 0].max().item())
        max_w = int(tgt_sizes[:, 1].max().item())
        if max_h > self.max_size[0] or max_w > self.max_size[1]:
            self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
            self._set_2d_pos_cache(self.max_size, device)

    def _init_weights(self, m):
        """Initializer for a single submodule (linear and layer-norm layers)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, tgt_sizes=None):
        """Resample patch features into ``num_queries`` vectors per sample.

        Args:
            x: input features, batch first (B * L * kv_dim)
            tgt_sizes: integer tensor of shape (B, 2) holding the (height, width)
                patch grid of every sample. Required despite the ``None``
                default kept for signature compatibility.

        Returns:
            Tensor of shape (B, num_queries, embed_dim).

        Raises:
            ValueError: if ``tgt_sizes`` is None.
        """
        if tgt_sizes is None:
            raise ValueError("tgt_sizes must be provided to Resampler.forward")
        assert x.shape[0] == tgt_sizes.shape[0]
        bs = x.shape[0]

        device = x.device
        dtype = x.dtype

        patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

        self._adjust_pos_cache(tgt_sizes, device=device)

        # plain int so it can be used as a tensor dimension
        max_patch_len = int(patch_len.max().item())
        key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device)

        pos_embed = []
        for i in range(bs):
            tgt_h, tgt_w = tgt_sizes[i]
            pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype))  # patches * D
            # positions beyond this sample's own patch count are padding
            key_padding_mask[i, patch_len[i] :] = True

        pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
            1, 0, 2
        )  # BLD => L * B * D

        x = self.kv_proj(x)  # B * L * D
        x = self.ln_kv(x).permute(1, 0, 2)  # L * B * D

        q = self.ln_q(self.query)  # Q * D

        # positional information is added to the keys only; the values stay position-free
        out = self.attn(
            self._repeat(q, bs),  # Q * B * D
            x + pos_embed,  # L * B * D + L * B * D
            x,
            key_padding_mask=key_padding_mask,
        )[0]
        # out: Q * B * D
        x = out.permute(1, 0, 2)  # B * Q * D

        x = self.ln_post(x)
        x = x @ self.proj
        return x

    def _repeat(self, query, N: int):
        """Tile the (Q, D) queries to (Q, N, D) for a batch of N."""
        return query.unsqueeze(1).repeat(1, N, 1)
class MiniCPMWhisperEncoderLayer(nn.Module):
    """Whisper-style pre-norm encoder layer that can also return the attention cache.

    The layer applies self-attention and a two-layer feed-forward block, each
    preceded by layer normalization and wrapped in a residual connection.
    """

    def __init__(self, config: WhisperConfig, layer_idx: int = None):
        """
        Args:
            config: Whisper configuration providing the model width, head
                count, dropout rates, activation and feed-forward size
            layer_idx: index of this layer, forwarded to the attention module
        """
        super().__init__()
        self.embed_dim = config.d_model
        try:
            # compatible old transformers
            from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES

            self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
                layer_idx=layer_idx,
            )
        except Exception:
            # Any ordinary failure of the legacy path (missing symbol, unknown attention
            # implementation, ...) falls back to the plain attention class. A bare
            # ``except:`` would also swallow KeyboardInterrupt and SystemExit.
            from transformers.models.whisper.modeling_whisper import WhisperAttention

            self.self_attn = WhisperAttention(
                embed_dim=self.embed_dim,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
                layer_idx=layer_idx,
            )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = False,
    ) -> Tuple:
        """Run the layer.

        Args:
            hidden_states: layer input
            attention_mask: mask forwarded to self-attention
            layer_head_mask: per-head mask forwarded to self-attention
            output_attentions: if True, the attention weights are appended to
                the returned tuple
            past_key_values: cache forwarded to self-attention (as
                ``past_key_value``)
            use_cache: if True, the cache returned by self-attention is
                appended to the returned tuple

        Returns:
            Tuple ``(hidden_states,)`` optionally followed by the attention
            weights and then the cache, in that order.
        """
        # self-attention block (pre-norm + residual)
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, attn_weights, past_key_values = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_values,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # feed-forward block (pre-norm + residual)
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # fp16 overflow guard: clamp to just below the dtype maximum when inf/nan appear
        if hidden_states.dtype == torch.float16 and (
            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if use_cache:
            outputs += (past_key_values,)

        return outputs
# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoder, with use_cache added for streaming inference
class MiniCPMWhisperEncoder(WhisperEncoder):
    """Whisper encoder whose layers accept a key/value cache for streaming inference.

    The layer stack created by ``WhisperEncoder`` is replaced by
    ``MiniCPMWhisperEncoderLayer`` instances, and ``forward`` additionally
    supports cached decoding, short-input padding and trimming of extra
    context frames.
    """

    def __init__(self, config: WhisperConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]
        )

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        past_key_values: Optional[EncoderDecoderCache] = None,
        use_cache: Optional[bool] = None,
        use_extra_context: Optional[bool] = False,
        prefix_extra_frames: Optional[int] = 1,
        suffix_extra_frames: Optional[int] = 1,
        cnn_min_length: Optional[int] = None,
    ):
        """Encode audio features, optionally continuing from a cache.

        Args:
            input_features: Input whose dimension 2 is the time axis; it is cast to
                the dtype/device of ``self.conv1``.
            attention_mask: Passed through to every layer.
            head_mask: Optional per-layer head mask; its first dimension must equal
                the number of layers.
            output_attentions: Collect per-layer attention weights (defaults to the config value).
            output_hidden_states: Collect per-layer hidden states (defaults to the config value).
            return_dict: Return a ``BaseModelOutputWithPast`` instead of a tuple
                (defaults to the config value).
            past_key_values: ``None``, a legacy list, a ``DynamicCache`` or an
                ``EncoderDecoderCache``; only consulted when ``use_cache`` is true.
            use_cache: Enable cached (streaming) operation.
            use_extra_context: The input carries extra prefix/suffix frames whose
                convolution outputs must be dropped.
            prefix_extra_frames: Number of extra input frames at the start.
            suffix_extra_frames: Number of extra input frames at the end.
            cnn_min_length: If given, shorter inputs are zero-padded to this length
                before the convolutions and cropped again afterwards.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true, otherwise a tuple
            of the non-``None`` values among (last hidden state, hidden states,
            attentions). The tuple form does not include the cache.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Ignore copy
        input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device)

        # Optional: pad short input to minimum length for CNN computation consistency
        original_length = input_features.shape[2]
        padded_for_cnn = cnn_min_length is not None and original_length < cnn_min_length
        if padded_for_cnn:
            padded_features = torch.zeros(
                input_features.shape[0],
                input_features.shape[1],
                cnn_min_length,
                dtype=input_features.dtype,
                device=input_features.device,
            )
            padded_features[:, :, :original_length] = input_features
            input_features = padded_features

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        # Undo the padding: conv1 keeps the length (stride 1), conv2 halves it
        # rounding up (stride 2), so the real output length is (original + 1) // 2.
        if padded_for_cnn:
            actual_cnn_output_length = (original_length + 1) // 2
            inputs_embeds = inputs_embeds[:, :, :actual_cnn_output_length]

        # Extra context frames are halved (rounding up) by conv2's stride of 2;
        # e.g. 2 redundant input frames become ceil(2 / 2) = 1 output frame.
        if use_extra_context:
            prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0
            suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0
            # Layout here is (batch, channels, time)
            if prefix_to_remove > 0:
                inputs_embeds = inputs_embeds[:, :, prefix_to_remove:]
            if 0 < suffix_to_remove < inputs_embeds.shape[2]:
                inputs_embeds = inputs_embeds[:, :, :-suffix_to_remove]

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        seq_len = inputs_embeds.shape[1]

        embed_pos = self.embed_positions.weight
        past_key_values_length = 0
        if use_cache:
            if past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
            elif isinstance(past_key_values, list):
                past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache())
            elif isinstance(past_key_values, DynamicCache):
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            past_key_values_length = past_key_values.self_attention_cache.get_usable_length(seq_len)

            if seq_len + past_key_values_length > embed_pos.shape[0]:
                logger.warning("seems the audio is longer than 30s. repeating the last part of the audio")
                embed_pos_front = embed_pos[past_key_values_length:, :]
                # Pad with copies of the last position up to exactly seq_len rows.
                # The count is derived from what the slice really produced, so a
                # cache that already exceeds the table no longer over-pads.
                missing_rows = seq_len - embed_pos_front.shape[0]
                embed_pos = torch.cat(
                    (
                        embed_pos_front,
                        torch.repeat_interleave(embed_pos[-1, :].unsqueeze(0), missing_rows, dim=0),
                    )
                )
            else:
                embed_pos = embed_pos[past_key_values_length : seq_len + past_key_values_length, :]
        else:
            embed_pos = embed_pos[:seq_len, :]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        # Defined up front so the return below works even if no layer ran
        # (every layer skipped by LayerDrop, or an empty layer list).
        next_encoder_cache = None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            # Ignore copy
            if to_drop:
                layer_outputs = (None, None)
            else:
                layer_head_mask = head_mask[idx] if head_mask is not None else None
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                        output_attentions,
                        past_key_values,
                        use_cache,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                        past_key_values=past_key_values,
                        use_cache=use_cache,
                    )

                hidden_states = layer_outputs[0]

                # The cache sits after the attention weights when those are returned.
                if use_cache:
                    next_encoder_cache = layer_outputs[2 if output_attentions else 1]
                else:
                    next_encoder_cache = None

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            past_key_values=next_encoder_cache,
        )
class MultiModalProjector(nn.Module):
    """Two-layer projector: ``Linear(in_dim -> out_dim)``, ReLU, ``Linear(out_dim -> out_dim)``."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Both linear layers carry a bias; the second one keeps the width unchanged.
        self.linear1 = nn.Linear(in_dim, out_dim, bias=True)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(out_dim, out_dim, bias=True)

    def forward(self, audio_features):
        """Project ``audio_features`` (last dimension ``in_dim``) to ``out_dim`` features."""
        return self.linear2(self.relu(self.linear1(audio_features)))
class MiniCPMMLP(nn.Module):
    """Gated MLP mapping ``config.llm_hidden_size`` features to ``config.hidden_size``.

    The output is ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with the
    activation selected by ``config.hidden_act``; all three projections have a bias.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.in_dim = config.llm_hidden_size
        self.out_dim = config.hidden_size
        self.intermediate_size = config.llm_intermediate_size

        hidden = self.intermediate_size
        self.gate_proj = nn.Linear(self.in_dim, hidden, bias=True)
        self.up_proj = nn.Linear(self.in_dim, hidden, bias=True)
        self.down_proj = nn.Linear(hidden, self.out_dim, bias=True)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` (last dimension ``in_dim``)."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
@dataclass
class MiniCPMTTSGenerationOutput(ModelOutput):
    """
    Output class for MiniCPMTTS generation.

    Every field defaults to ``None``; producers fill in only what they have.

    Args:
        new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq).
        audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq).
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head).
        past_input_ids (torch.LongTensor): Optional; presumably the input IDs already consumed by an earlier call -- TODO confirm against the code that sets it.
        finished (bool): Boolean indicating whether generation is complete.
    """

    new_ids: Optional[torch.LongTensor] = None
    audio_input_ids: Optional[torch.LongTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_input_ids: Optional[torch.LongTensor] = None
    finished: Optional[bool] = None
def make_streaming_chunk_mask_inference(
    tts_text_scope: List[int],
    tts_text_mask: torch.Tensor,
    streaming_audio_chunk_size: int = 50,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device = torch.device("cuda"),
    max_sequence_length: int = 4096,
    num_text_tokens_per_audio_chunk: int = 10,
):
    """Build the additive 4D attention mask used for chunked ("streaming") TTS inference.

    The mask is causal, and on top of that every audio chunk may only look at a
    growing prefix of the text: each new chunk of ``streaming_audio_chunk_size``
    audio positions reveals ``num_text_tokens_per_audio_chunk`` more text tokens.
    The last text-scope position ([Ptts]) is never hidden by the chunk rule.

    Example:
        Input sequence:
            [t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...]
        Output 4D causal mask:
            ------- text positions -------
            [0]                 <- here is [Stts]
            [0, 0]              <- here is [spk_emb] * N
            [0, 0, 0]
            [0, 0, 0, 0]
            [0, 0, 0, 0, 0]
            ------- audio positions --------
            [0, 0, -inf, -inf, -inf, 0]   <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token
                                     v- here is [Ptts]
            [0, 0, -inf, -inf, -inf, 0, 0]
            [0, 0, -inf, -inf, -inf, 0, 0, 0]
            [0, 0, -inf, -inf, -inf, 0, 0, 0, 0]
            [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0]
            [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0]   # end of first 1s audio chunk
            [0, 0, 0   , -inf, -inf, 0, 0, 0, 0, 0, 0, 0]
            [0, 0, 0   , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0]
            [0, 0, 0   , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            [0, 0, 0   , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
            [0, 0, 0   , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

    Args:
        tts_text_scope: ``[start, end)`` positions of the text segment; audio starts at ``end``.
        tts_text_mask: ``int8`` vector of length ``end - start``; zeros mark padding
            positions that no row may attend to.
        streaming_audio_chunk_size: Number of audio positions per chunk.
        dtype: dtype of the returned mask; "masked" is ``torch.finfo(dtype).min``.
        device: Device on which the mask is built.
        max_sequence_length: Side length of the square mask.
        num_text_tokens_per_audio_chunk: Text tokens newly revealed to each audio chunk.

    Returns:
        Tensor of shape ``(1, 1, max_sequence_length, max_sequence_length)`` holding
        ``0`` where attention is allowed and ``torch.finfo(dtype).min`` elsewhere.

    Raises:
        AssertionError: If ``tts_text_mask`` is not ``int8``.
        ValueError: If ``max_sequence_length`` is 1 or
            ``num_text_tokens_per_audio_chunk`` is smaller than 1.
    """
    assert tts_text_mask.dtype == torch.int8
    # Reject the degenerate size before allocating anything.
    if max_sequence_length == 1:
        raise ValueError("max_sequence_length of tts could not be 1.")
    if num_text_tokens_per_audio_chunk < 1:
        raise ValueError(
            f"num_text_tokens_per_audio_chunk must be >= 1, but got {num_text_tokens_per_audio_chunk}"
        )

    text_start = tts_text_scope[0]
    text_end = tts_text_scope[1]

    # Attention mask over the full sequence; audio is always at the end, so only
    # the text segment can contain padding.
    padding_mask = torch.ones(max_sequence_length, dtype=torch.int8, device=device)
    padding_mask[text_start:text_end] = tts_text_mask

    # Standard causal mask: everything strictly above the diagonal is masked.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full(
        (max_sequence_length, max_sequence_length),
        fill_value=min_dtype,
        dtype=dtype,
        device=device,
    )
    causal_mask = torch.triu(causal_mask, diagonal=1)

    audio_token_start = text_end
    audio_duration = max_sequence_length - text_end
    num_valid_text_tokens = torch.sum(tts_text_mask).item() - 1  # [Ptts] excluded

    # Right bound (relative to text_start) of the text visible to the current chunk.
    text_pivot = 0
    for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)):
        # Rows are shifted by one because the row of [Ptts] (just before the first
        # audio position) already predicts the first audio token.
        row_start = audio_token_start + chunk_idx * streaming_audio_chunk_size - 1
        row_end = row_start + streaming_audio_chunk_size
        text_pivot = min(text_pivot + num_text_tokens_per_audio_chunk, num_valid_text_tokens)
        # Hide all text after the visible prefix, but never [Ptts] (column text_end - 1).
        causal_mask[row_start:row_end, text_start + text_pivot : text_end - 1] = min_dtype

    # Padding positions inside the text segment are invisible to every row.
    causal_mask[:, padding_mask == 0] = min_dtype

    # [seq_len, seq_len] -> [1, 1, seq_len, seq_len]
    return causal_mask.unsqueeze(0).unsqueeze(0)
class MiniCPMTTS(PreTrainedModel):
    """Autoregressive generator of discrete audio codes on top of a Llama backbone.

    Conditioning embeddings are run through the backbone and ``num_vq``
    weight-normalised linear heads turn the last hidden state into one
    distribution per codebook. Generation entry points:

    * ``generate`` -- plain, non-streaming decoding.
    * ``generate_mock_legacy_streaming`` -- same loop under a chunked text-visibility mask.
    * ``generate_chunk`` / ``interleaved_generate`` -- chunk-wise decoding of
      interleaved text/audio conditions.
    """

    config_class = MiniCPMTTSConfig

    def __init__(self, config: MiniCPMTTSConfig, audio_tokenizer: None):
        super().__init__(config)
        self.use_llm_hidden_state = config.use_llm_hidden_state
        self.use_text = config.use_text
        self.streaming = config.streaming
        self.streaming_text_chunk_min = config.streaming_text_chunk_min
        self.streaming_text_chunk_max = config.streaming_text_chunk_max
        self.streaming_audio_chunk_size = config.streaming_audio_chunk_size
        self.streaming_text_reserved_len = config.streaming_text_reserved_len
        # streaming tts
        self.streaming_text_chunk_size = config.streaming_text_chunk_max
        self.audio_bos_token_id = config.audio_bos_token_id
        self.num_mel_bins = config.num_mel_bins
        self.num_vq = config.num_vq
        self.num_audio_tokens = config.num_audio_tokens
        self.top_p = config.top_p
        self.top_k = config.top_k
        self.repetition_penalty = config.repetition_penalty
        self.interleaved = config.interleaved
        self.attention_type = config.attention_type
        self.recomputed_chunks = config.recomputed_chunks

        # Two different window size concepts:
        # 1. chunk_window_size: number of chunks for sliding_recompute mode (default 2)
        # 2. token_window_size: number of tokens for sliding_window mode (default 300)
        self.chunk_window_size = config.window_size  # chunk-level window for sliding_recompute
        self.token_window_size = (
            config.streaming_sliding_window_audio_window_size
        )  # token-level window for sliding_window
        # Legacy aliases (for backward compatibility with existing code)
        self.window_size = self.chunk_window_size  # used in generate_streaming for sliding_recompute
        self.sliding_window_size = self.token_window_size  # used in TTSStreamingGenerator for sliding_window

        if self.attention_type == "sliding_recompute" and self.chunk_window_size <= self.recomputed_chunks:
            raise ValueError(
                f"sliding_recompute requires chunk_window_size > recomputed_chunks, "
                f"but got chunk_window_size={self.chunk_window_size} and recomputed_chunks={self.recomputed_chunks}"
            )

        if config.backbone_model != "llama":
            raise ValueError(f"Unsupported backbone model: {config.backbone_model}")
        model_config = LlamaConfig(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            num_key_value_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            attn_implementation=config.attn_implementation,
        )
        self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
        self.model = LlamaModel(model_config)

        self.projector_spk = self.create_projector(config)
        self.projector_semantic = self.create_projector(config)
        self.audio_tokenizer = audio_tokenizer

        # One embedding table and one output head per codebook.
        self.emb_code = nn.ModuleList(
            [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)]
        )
        self.head_code = nn.ModuleList(
            [
                weight_norm(
                    nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                    name="weight",
                )
                for _ in range(config.num_vq)
            ]
        )
        self.condition_type = config.condition_type

    @staticmethod
    def create_projector(config):
        """Build the projector selected by ``config.projector_type``.

        Raises:
            ValueError: For any type other than ``"mlp"``, ``"minicpm"`` or ``"default"``.
        """
        if config.projector_type == "mlp":
            return MultiModalProjector(config.llm_dim, config.hidden_size)
        elif config.projector_type == "minicpm":
            return MiniCPMMLP(config)
        elif config.projector_type == "default":
            return nn.Linear(config.llm_dim, config.hidden_size, bias=False)
        else:
            raise ValueError(f"Unsupported projector type: {config.projector_type}")

    # ------------------------------------------------------------------
    # Private helpers shared by the generation entry points
    # ------------------------------------------------------------------
    @staticmethod
    def _eos_as_tensor(eos_token, device):
        """Return ``eos_token`` (int or tensor) as a tensor on ``device``."""
        return torch.as_tensor(eos_token, device=device)

    @staticmethod
    def _per_row_temperature(temperature, batch_size, device):
        """Tile a per-codebook temperature vector into a ``(batch_size * num_vq, 1)`` column."""
        return temperature.unsqueeze(0).expand(batch_size, -1).contiguous().view(-1, 1).to(device)

    def _embed_codes(self, codes):
        """Sum the per-codebook embeddings of ``codes`` (``[batch, 1, num_vq]``)."""
        per_codebook = [self.emb_code[q](codes[:, :, q]) for q in range(self.num_vq)]
        return torch.stack(per_codebook, 3).sum(3)

    def _backbone_step(self, inputs_embeds, position_ids, past_key_values, attention_mask):
        """Run the backbone once with caching enabled, passing ``position_ids`` as ``cache_position`` too."""
        if self.config.backbone_model != "llama":
            raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}")
        return self.model(
            position_ids=position_ids,
            cache_position=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            use_cache=True,
            output_attentions=False,
        )

    def _last_step_code_logits(self, hidden_states):
        """Apply every codebook head and return last-position logits as ``[batch * num_vq, num_audio_tokens]``."""
        # P.cached() avoids recomputing the weight-norm parametrisation for each head call.
        with P.cached():
            logits = torch.empty(
                hidden_states.size(0),
                hidden_states.size(1),
                self.num_audio_tokens,
                self.num_vq,
                dtype=torch.float,
                device=self.device,
            )
            for q in range(self.num_vq):
                logits[..., q] = self.head_code[q](hidden_states)
        logits = logits[:, -1].float()
        logits = logits.permute(0, 2, 1)
        return logits.reshape(-1, logits.size(2))

    def _previous_codes(self, new_tokens, t):
        """Return the first ``t`` generated codes reshaped to ``[batch * num_vq, t]`` for the logits processors."""
        history = new_tokens[:, 0:t].permute(0, 2, 1)
        return history.reshape(history.size(0) * history.size(1), -1).to(self.device)

    def _sample_codes(self, logits, eos_token, finish):
        """Sample one code per row of ``logits`` and OR the EOS hits into ``finish`` (in place)."""
        scores = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
        idx_next = idx_next.view(-1, self.num_vq)
        finish.logical_or_(idx_next.eq(eos_token).any(1))
        return idx_next

    def _decode_codes(
        self,
        inputs_embeds,
        eos_token,
        temperature,
        logits_warpers,
        logits_processors,
        force_no_stop,
        min_new_token,
        max_new_token,
        show_tqdm,
        step_mask=None,
    ):
        """Shared prefill-then-sample loop of ``generate`` and ``generate_mock_legacy_streaming``.

        ``step_mask(t, condition_length, past_key_values)`` supplies the attention
        mask for step ``t``; ``None`` means no explicit mask is passed to the backbone.
        """
        eos_token = self._eos_as_tensor(eos_token, inputs_embeds.device)
        finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()
        condition_length = inputs_embeds.shape[1]

        pbar: Optional[tqdm] = None
        if show_tqdm:
            pbar = tqdm(
                total=max_new_token,
                desc="code",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
            )

        new_tokens = torch.zeros(
            inputs_embeds.shape[0],
            max_new_token,
            self.num_vq,
            device=inputs_embeds.device,
            dtype=torch.long,
        )

        past_key_values = None
        t = 0  # keeps the final slice well-defined when max_new_token == 0
        for t in range(max_new_token):
            audio_bos = t == 0
            if audio_bos:
                # First step: prefill with the whole condition.
                step_embeds = inputs_embeds
                position_ids = torch.arange(condition_length, dtype=torch.long, device=self.device).unsqueeze(0)
            else:
                # Later steps: feed back the codes sampled at the previous step.
                step_embeds = self._embed_codes(new_tokens[:, t - 1 : t])
                position_ids = torch.tensor(
                    [condition_length + t - 1], dtype=torch.long, device=self.device
                ).unsqueeze(0)

            attention_mask = None if step_mask is None else step_mask(t, condition_length, past_key_values)
            outputs: BaseModelOutputWithPast = self._backbone_step(
                step_embeds, position_ids, past_key_values, attention_mask
            )
            past_key_values = outputs.past_key_values

            logits = self._last_step_code_logits(outputs.last_hidden_state)
            logits /= temperature

            if not audio_bos:
                history = self._previous_codes(new_tokens, t)
                for processor in logits_processors:
                    logits = processor(history, logits)
                for warper in logits_warpers:
                    logits = warper(history, logits)

            if force_no_stop or t < min_new_token:
                logits[:, eos_token] = -torch.inf

            new_tokens[:, t] = self._sample_codes(logits, eos_token, finish)

            # Batch size is 1, so any() and all() coincide; one check covers the
            # "EOS at the very first step" case as well.
            if finish.all():
                break
            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        if not finish.all():
            logger.warning(f"incomplete result. hit max_new_token: {max_new_token}")

        # The code sampled at the last executed step (EOS or not) is not returned.
        genrated_input_ids = new_tokens[:, 0:t, :]
        return MiniCPMTTSGenerationOutput(
            new_ids=genrated_input_ids,
            audio_input_ids=None,  # for update purpose
            past_key_values=None,  # for update purpose
            past_input_ids=None,  # for update purpose
            finished=finish.all(),
        )

    # non-streaming
    @torch.inference_mode()
    def generate(
        self,
        inputs_embeds: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        force_no_stop=False,
        min_new_token=50,
        max_new_token=2048,
        show_tqdm=True,
        streaming=False,
        text_lengths=None,
        sampling_params: TTSSamplingParams = TTSSamplingParams(),
    ):
        """Generate audio codes for a single conditioning sequence (non-streaming).

        Args:
            inputs_embeds: Conditioning embeddings, ``[1, seq_len, hidden]`` (batch size must be 1).
            eos_token: End-of-sequence code, as an int or a tensor.
            force_no_stop: Mask the EOS code at every step.
            min_new_token: EOS is masked for the first ``min_new_token`` steps.
            max_new_token: Maximum number of sampling steps.
            show_tqdm: Show a progress bar.
            streaming: Must be false; ``True`` raises ``NotImplementedError``.
            text_lengths: Unused.
            sampling_params: Supplies temperature, repetition penalty, top-p and top-k.

        Returns:
            ``MiniCPMTTSGenerationOutput`` with ``new_ids`` and ``finished`` set.
        """
        temperature = torch.tensor(
            [sampling_params.temperature] * self.config.num_vq,
            dtype=torch.float,
            device=self.device,
        )
        temperature = self._per_row_temperature(temperature, inputs_embeds.shape[0], inputs_embeds.device)

        logits_warpers, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens,
            repetition_penalty=sampling_params.repetition_penalty,
            top_p=sampling_params.top_p,
            top_k=sampling_params.top_k,
        )

        # We only support batch size `1` for now
        assert inputs_embeds.shape[0] == 1
        # Checked once up front, before a progress bar is opened.
        if streaming:
            raise NotImplementedError("this kind of streaming is not supported yet")

        return self._decode_codes(
            inputs_embeds=inputs_embeds,
            eos_token=eos_token,
            temperature=temperature,
            logits_warpers=logits_warpers,
            logits_processors=logits_processors,
            force_no_stop=force_no_stop,
            min_new_token=min_new_token,
            max_new_token=max_new_token,
            show_tqdm=show_tqdm,
            step_mask=None,
        )

    # fake streaming
    @torch.inference_mode()
    def generate_mock_legacy_streaming(
        self,
        inputs_embeds: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        force_no_stop=False,
        min_new_token=50,
        max_new_token=2048,
        show_tqdm=True,
        streaming=False,
        text_lengths=None,
        sampling_params: TTSSamplingParams = TTSSamplingParams(),
        valid_text_length=None,
    ):
        """Like ``generate``, but decodes under the chunked text-visibility mask.

        The mask comes from ``make_streaming_chunk_mask_inference`` with an audio
        chunk size of 50 and a maximum sequence length of 4096.

        Args:
            inputs_embeds: Conditioning embeddings, ``[1, seq_len, hidden]``; the
                last position is treated as [Ptts].
            eos_token: End-of-sequence code, as an int or a tensor.
            force_no_stop: Mask the EOS code at every step.
            min_new_token: EOS is masked for the first ``min_new_token`` steps.
            max_new_token: Maximum number of sampling steps.
            show_tqdm: Show a progress bar.
            streaming: Unused.
            text_lengths: Unused.
            sampling_params: Supplies repetition penalty, top-p and top-k.
            valid_text_length: Required; number of leading condition positions
                that hold real (non-padding) content.

        Returns:
            ``MiniCPMTTSGenerationOutput`` with ``new_ids`` and ``finished`` set.
        """
        assert valid_text_length is not None, "valid_text_length should be not None"

        tts_text_scope = [0, inputs_embeds.shape[1]]
        tts_text_mask = torch.zeros(inputs_embeds.shape[1], dtype=torch.int8, device=inputs_embeds.device)
        tts_text_mask[0:valid_text_length] = 1
        tts_text_mask[-1] = 1  # [Ptts]

        streaming_mask_4d_full = make_streaming_chunk_mask_inference(
            tts_text_scope=tts_text_scope,
            tts_text_mask=tts_text_mask,
            dtype=torch.bfloat16,
            device=self.device,
            streaming_audio_chunk_size=50,
            max_sequence_length=4096,
        )

        # NOTE(review): fixed four-entry temperature -- this ignores
        # sampling_params.temperature and only broadcasts when num_vq is 4; confirm intent.
        temperature = torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.device)
        temperature = self._per_row_temperature(temperature, inputs_embeds.shape[0], inputs_embeds.device)

        logits_warpers, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens,
            repetition_penalty=sampling_params.repetition_penalty,
            top_p=sampling_params.top_p,
            top_k=sampling_params.top_k,
        )

        # We only support batch size `1` for now
        assert inputs_embeds.shape[0] == 1

        def step_mask(t, condition_length, past_key_values):
            if t == 0:
                return streaming_mask_4d_full[:, :, :condition_length, :condition_length]
            mask = streaming_mask_4d_full[
                :,
                :,
                condition_length + t : condition_length + t + 1,
                : condition_length + t,
            ]
            # The single query row must cover the cached positions plus itself.
            past_key_values_length = past_key_values[0][0].shape[2]
            assert mask.shape[-1] == (past_key_values_length + 1)
            return mask

        return self._decode_codes(
            inputs_embeds=inputs_embeds,
            eos_token=eos_token,
            temperature=temperature,
            logits_warpers=logits_warpers,
            logits_processors=logits_processors,
            force_no_stop=force_no_stop,
            min_new_token=min_new_token,
            max_new_token=max_new_token,
            show_tqdm=show_tqdm,
            step_mask=step_mask,
        )

    # non-streaming, interleave
    @torch.inference_mode()
    def generate_chunk(
        self,
        inputs_embeds: torch.Tensor,
        temperature: torch.Tensor,
        repetition_penalty: float,
        eos_token: Union[int, torch.Tensor],
        force_no_stop=False,
        max_new_token=500,
        min_new_tokens=0,
        past_key_values=None,
        logits_processors=None,
        text_start_pos=None,
    ):
        """Generate the audio codes of one chunk, continuing from ``past_key_values``.

        For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like:
        |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS|
        where the last position is the audio BOS token.
        So, the first iteration in generation directly forward the model with inputs_embeds, and
        the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token.

        Args:
            inputs_embeds: Chunk condition, ``[1, seq_len, hidden]``.
            temperature: Per-codebook temperature vector.
            repetition_penalty: Penalty used to build the repetition processor.
            eos_token: End-of-sequence code, as an int or a tensor.
            force_no_stop: Mask the EOS code at every step.
            max_new_token: Maximum number of sampling steps.
            min_new_tokens: EOS is masked for the first ``min_new_tokens`` steps.
            past_key_values: Cache from previous chunks, or ``None``.
            logits_processors: Accepted for interface compatibility; the processors
                are rebuilt from ``repetition_penalty`` below.
            text_start_pos: Position id of the first condition token; ``None`` means 0.

        Returns:
            ``(codes, past_key_values)`` where ``codes`` is ``[1, n, num_vq]``.
        """
        # NOTE(review): the `logits_processors` argument is overwritten here and the
        # warpers returned by gen_logits are never applied in this method -- kept as is.
        _, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty
        )

        # We only support batch size `1` for now
        assert inputs_embeds.shape[0] == 1
        eos_token = self._eos_as_tensor(eos_token, inputs_embeds.device)
        finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool()
        temperature = self._per_row_temperature(temperature, inputs_embeds.shape[0], inputs_embeds.device)

        # The signature default (None) previously crashed in range(); treat it as "start at 0".
        if text_start_pos is None:
            text_start_pos = 0

        condition_length = inputs_embeds.shape[1]
        new_tokens = torch.zeros(
            inputs_embeds.shape[0],
            max_new_token,
            self.num_vq,
            device=inputs_embeds.device,
            dtype=torch.long,
        )

        t = 0  # keeps the final slice well-defined when max_new_token == 0
        for t in range(max_new_token):
            audio_bos = t == 0
            if audio_bos:
                # First step: prefill with the whole chunk condition.
                inputs_embeds_ = inputs_embeds
                position_ids = torch.arange(
                    text_start_pos,
                    text_start_pos + condition_length,
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)
            else:
                # Later steps feed back the previous code; only codebook 0 is embedded here.
                inputs_embeds_ = self.emb_code[0](new_tokens[:, t - 1 : t, 0])
                position_ids = torch.tensor(
                    [text_start_pos + condition_length + t - 1],  # prefill the previous token
                    dtype=torch.long,
                    device=self.device,
                ).unsqueeze(0)

            outputs: BaseModelOutputWithPast = self.model(
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds_,
                use_cache=True,
                output_attentions=False,
            )
            past_key_values = outputs.past_key_values

            logits = self._last_step_code_logits(outputs.last_hidden_state)
            logits /= temperature

            if not audio_bos:
                history = self._previous_codes(new_tokens, t)
                for processor in logits_processors:
                    logits = processor(history, logits)

            if force_no_stop or t < min_new_tokens:
                logits[:, eos_token] = -torch.inf

            new_tokens[:, t] = self._sample_codes(logits, eos_token, finish)

            # Batch size is 1, so any() and all() coincide.
            if finish.all():
                break

        # The code sampled at the last executed step is never returned, EOS or not.
        # It has not been fed through the model, so the returned codes match the
        # positions held in the returned cache.
        genrated_input_ids = new_tokens[:, 0:t, :]
        return genrated_input_ids, past_key_values

    @torch.inference_mode()
    def interleaved_generate(
        self,
        spk_embeds: torch.Tensor,
        conditions: List[torch.Tensor],
        temperature: torch.Tensor,
        repetition_penalty: float,
        eos_token: Union[int, torch.Tensor],
        **kwargs,
    ):
        """Generate audio codes chunk by chunk from a list of conditions.

        For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like:
        |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS|
        where the last position is the audio BOS token.
        So, the first iteration in generation directly forward the model with inputs_embeds, and the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token.

        Cache handling depends on ``self.attention_type``:

        * ``"sliding_recompute"``: periodically the cache is dropped and the last
          ``recomputed_chunks`` chunks (conditions plus their generated codes) are
          prepended to the current condition and re-encoded from position 0.
        * ``"sliding_window"``: from the second chunk on, the cache entries of the
          previous window are cut off the front of the cache.
        * anything else: the cache simply keeps growing.

        Args:
            spk_embeds: Unused in this method.
            conditions: One condition tensor per chunk; all are moved to the device of the first.
            temperature: Sampling temperature; wrapped into a one-element tensor.
            repetition_penalty: Passed on to ``generate_chunk``.
            eos_token: End-of-sequence code, as an int or a tensor.
            **kwargs: Ignored.

        Returns:
            ``MiniCPMTTSGenerationOutput`` with all chunk codes concatenated along
            dimension 1 and ``finished=True``.
        """
        temperature = torch.tensor([temperature], dtype=torch.float, device=self.device)

        _, logits_processors = gen_logits(
            num_code=self.config.num_audio_tokens,
            repetition_penalty=repetition_penalty,
        )

        target_device = conditions[0].device
        eos_token = self._eos_as_tensor(eos_token, target_device)

        recompute_mode = self.attention_type == "sliding_recompute"
        text_start_pos = 0
        last_window_size = 0
        past_key_values = None
        generated_tokens = []

        for idx, raw_condition in enumerate(conditions):
            condition = raw_condition.to(target_device)

            # `and` short-circuits, so the modulo is only evaluated in recompute
            # mode, where __init__ guarantees window_size > recomputed_chunks.
            restart = (
                recompute_mode
                and idx >= self.window_size
                and (idx - self.recomputed_chunks) % (self.window_size - self.recomputed_chunks) == 0
            )
            if restart:
                # Rebuild the context: previous conditions interleaved with the
                # embeddings of their generated codes, followed by the current condition.
                pieces = []
                for i in range(self.recomputed_chunks):
                    pieces.append(conditions[idx - self.recomputed_chunks + i])
                    pieces.append(self.emb_code[0](generated_tokens[-self.recomputed_chunks + i][:, :, 0]))
                pieces.append(condition)
                condition = torch.cat(pieces, dim=1)
                text_start_pos = 0
                past_key_values = None

            new_tokens, old_kv = self.generate_chunk(
                inputs_embeds=condition,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                eos_token=eos_token,
                force_no_stop=False,
                max_new_token=500,
                past_key_values=past_key_values,
                logits_processors=logits_processors,
                text_start_pos=text_start_pos,
            )

            if self.attention_type == "sliding_window" and idx >= 1:
                # Drop the previous window from the front of every layer's keys and values.
                past_key_values = [
                    (
                        old_kv[layer_idx][0][:, :, last_window_size:, :],
                        old_kv[layer_idx][1][:, :, last_window_size:, :],
                    )
                    for layer_idx in range(len(old_kv))
                ]
            else:
                past_key_values = old_kv

            last_window_size = condition.shape[1] + new_tokens.shape[1]
            text_start_pos += last_window_size
            generated_tokens.append(new_tokens)

        return MiniCPMTTSGenerationOutput(new_ids=torch.cat(generated_tokens, dim=1), finished=True)
class CustomRepetitionPenaltyLogitsProcessorRepeat:
    """Repetition penalty that grows with how often a token occurred recently.

    For every row, the number of occurrences ``n`` of each token within the last
    ``past_window`` input ids is counted; negative scores are multiplied by
    ``penalty ** n`` and non-negative scores are divided by it.
    """

    def __init__(self, penalty: float, max_input_ids: int, past_window: int):
        is_valid = isinstance(penalty, float) and penalty > 0
        if not is_valid:
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty
        self.max_input_ids = max_input_ids
        self.past_window = past_window

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return penalised ``scores`` (``[rows, vocab]``) given ``input_ids`` (``[rows, seq]``)."""
        window = self.past_window
        if input_ids.size(1) > window:
            # Keep only the trailing `window` ids of every row.
            input_ids = input_ids.narrow(1, -window, window)

        counts = F.one_hot(input_ids, scores.size(1)).sum(1)

        # NOTE(review): the limit is compared against dimension 0 of the count
        # matrix (rows), not the vocabulary dimension -- confirm this is intended.
        surplus_rows = counts.size(0) - self.max_input_ids
        if surplus_rows > 0:
            counts.narrow(0, self.max_input_ids, surplus_rows).zero_()

        factor = torch.pow(self.penalty, counts)
        scores = scores.contiguous()
        return torch.where(scores < 0, scores.multiply(factor), scores.divide(factor))
def gen_logits(num_code: int, top_p=0.7, top_k=20, repetition_penalty=1.0):
    """Build the sampling warpers and logits processors for generation.

    Args:
        num_code: Forwarded to the repetition-penalty processor as its
            ``max_input_ids`` argument.
        top_p: Threshold for ``TopPLogitsWarper``; ``None`` skips it.
        top_k: Threshold for ``TopKLogitsWarper``; ``None`` skips it.
        repetition_penalty: Penalty for
            ``CustomRepetitionPenaltyLogitsProcessorRepeat``; ``None`` or a
            value equal to 1 skips it.

    Returns:
        A ``(logits_warpers, logits_processors)`` tuple of lists. Warpers keep
        at least 3 tokens; the penalty processor uses a window of 16 ids.
    """
    nucleus_stage = [] if top_p is None else [TopPLogitsWarper(top_p, min_tokens_to_keep=3)]
    top_k_stage = [] if top_k is None else [TopKLogitsWarper(top_k, min_tokens_to_keep=3)]
    logits_warpers = nucleus_stage + top_k_stage

    apply_penalty = repetition_penalty is not None and repetition_penalty != 1
    if apply_penalty:
        logits_processors = [CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)]
    else:
        logits_processors = []

    return logits_warpers, logits_processors
# Copied and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    **kwargs,
):
    """Assemble the keyword arguments for one forward pass during generation.

    Args:
        self: Object the function is bound to; only ``self.lm_head.weight.dtype``
            is read, and only when a 2-D mask is expanded for a ``StaticCache``.
        input_ids: Token ids, indexed as ``(batch, sequence)``.
        past_key_values: ``None``, a ``Cache`` instance, or a legacy per-layer
            sequence of ``(key, value)`` tensors whose dimension 2 is taken as
            the number of cached tokens.
        attention_mask: Optional attention mask. A 2-D mask is expanded to a
            4-D causal mask when ``past_key_values`` is a ``StaticCache``.
        inputs_embeds: Optional embeddings, used instead of ``input_ids`` on
            the first generation step only.
        cache_position: Optional positions of the current tokens;
            ``cache_position[0] == 0`` marks the first step. When omitted, the
            first step is recognized by the absence of cached tokens.
        position_ids: Optional position ids; derived from ``attention_mask``
            when omitted and a mask is available.
        use_cache: Forwarded unchanged.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A dict with the keys ``input_ids``, ``inputs_embeds``, ``position_ids``,
        ``past_key_values``, ``use_cache`` and ``attention_mask``. Exactly one
        of ``input_ids`` / ``inputs_embeds`` is not ``None``.
    """
    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, Cache):
            past_length = past_key_values.seen_tokens
        else:
            past_length = past_key_values[0][0].shape[2]

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting
        #     where some of the inputs are exclusively passed as part of the cache (e.g. when passing
        #     inputs_embeds as input).
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens
        #     and the already-processed prefix can be discarded.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids only has unprocessed tokens.

    if attention_mask is not None and position_ids is None:
        # Create position_ids on the fly for batch generation.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
            # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various strides
            # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
            # batch size = 1 case, `position_ids` is already contiguous but with varying stride, which
            # retriggers a capture.
            position_ids = position_ids.clone(memory_format=torch.contiguous_format)

    # If `inputs_embeds` are passed, we only want to use them in the 1st generation step.
    use_embeds = False
    if inputs_embeds is not None:
        if cache_position is not None:
            use_embeds = cache_position[0] == 0
        else:
            # `cache_position` is optional: without it, the first step is the one with no cached tokens.
            use_embeds = past_length == 0

    if use_embeds:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        # The clone here is for the same reason as for `position_ids`.
        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

    # `attention_mask` is optional, so it is checked for None before `.ndim` is read.
    expand_mask = (
        past_key_values is not None
        and attention_mask is not None
        and attention_mask.ndim == 2
        and isinstance(past_key_values, StaticCache)
    )
    if expand_mask:
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
            device = model_inputs["inputs_embeds"].device
        else:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

        dtype = self.lm_head.weight.dtype
        min_dtype = torch.finfo(dtype).min

        from transformers.models.paligemma.modeling_paligemma import (
            _prepare_4d_causal_attention_mask_with_cache_position,
        )

        attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_length(),
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            # "cache_position": cache_position,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
        }
    )
    return model_inputs