#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 2026 The OpenBMB Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import logging import math import os import tempfile import threading import time import types from copy import deepcopy from dataclasses import dataclass from functools import partial from threading import Thread from typing import Dict from typing import List from typing import Optional from typing import Tuple from typing import Union import numpy as np import torch import torch.nn.functional as F import torch.nn.utils.parametrize as P from torch import nn from torch.nn.init import trunc_normal_ from torch.nn.utils.parametrizations import weight_norm from tqdm import tqdm if os.getenv("USE_FLAGOS") == "1": import importlib flag_gems = importlib.import_module("flag_gems") # noqa: F401 flag_gems_experimental = importlib.import_module("flag_gems.experimental_ops") gems_rmsnorm = flag_gems_experimental.rmsnorm class GemsRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): return gems_rmsnorm(hidden_states, self.weight, self.variance_epsilon) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" from transformers.models.llama import modeling_llama from transformers.models.qwen3 import modeling_qwen3 modeling_qwen3.Qwen3RMSNorm = GemsRMSNorm modeling_llama.LlamaRMSNorm = GemsRMSNorm from transformers import LlamaConfig from transformers import LlamaModel from transformers import PreTrainedModel from transformers import Qwen3ForCausalLM from transformers import Qwen3PreTrainedModel from transformers import TextIteratorStreamer from transformers.activations import ACT2FN from transformers.cache_utils import Cache from transformers.cache_utils import DynamicCache from transformers.cache_utils import EncoderDecoderCache from transformers.cache_utils import StaticCache from transformers.generation.logits_process import TopKLogitsWarper from transformers.generation.logits_process import TopPLogitsWarper from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import ModelOutput from transformers.models.whisper.configuration_whisper import WhisperConfig from transformers.models.whisper.modeling_whisper import WhisperEncoder from .configuration_minicpmo import MiniCPMOConfig from .configuration_minicpmo import MiniCPMTTSConfig from .modeling_navit_siglip import SiglipVisionTransformer from .processing_minicpmo import MiniCPMOProcessor from .utils import as_dynamic_cache from .utils import ChunkPrefillChunkGenerate from .utils import drop_tokens_from_cache from .utils import DuplexWindowConfig from .utils import get_kv_cache_length from .utils import normalize_content from .utils import realign_rotary_suffix from .utils import SpeculativeSnapshot from .utils import streaming_token_decoder from .utils import StreamingWindowConfig from .utils import torch_clone_recursive from .utils import TTSSamplingParams from .utils import TTSStreamingGenerator logger = logging.getLogger(__name__) class MiniCPMOPreTrainedModel(Qwen3PreTrainedModel): config_class = MiniCPMOConfig class MiniCPMO(MiniCPMOPreTrainedModel): def __init__(self, config): super().__init__(config) self.llm = Qwen3ForCausalLM(config) self.embed_dim = self.llm.config.hidden_size self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm) # patch llm # init vision module if self.config.init_vision: self.vpm = self.init_vision_module() self.vision_dim = self.vpm.embed_dim self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) # init audio module if self.config.init_audio: self.apm = self.init_audio_module() audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4) self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step) self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim) self.audio_encoder_layer = -1 # init tts module if self.config.init_tts: self.tts = self.init_tts_module() self.terminators = ["<|im_end|>", "<|endoftext|>"] self.think_str = "" if self.llm.__class__.__name__ == "Qwen3ForCausalLM": self.think_str = "\\n\\n\\n\\n" # for streaming self.reset_session(reset_token2wav_cache=True) # streaming audio processing constants self.SAMPLE_RATE = 16000 self.CHUNK_MS = 1000 # regular chunk length (ms) self.FIRST_CHUNK_MS = 1035 # first chunk length (ms) self.CNN_REDUNDANCY_MS = 0 # CNN redundancy (ms) # for sliding window self.streaming_window_config = StreamingWindowConfig() self.streaming_require_system_prompt = True self.streaming_window_enabled = True self.force_rope_reindex = False # RoPE reindex testing switch def init_streaming_processor(self): self.prepare_processor(processor=None, tokenizer=None) if hasattr(self.processor, "set_streaming_mode"): self.processor.set_streaming_mode( mode="exact", chunk_ms=self.CHUNK_MS, first_chunk_ms=self.FIRST_CHUNK_MS, cnn_redundancy_ms=self.CNN_REDUNDANCY_MS, enable_sliding_window=True, slide_trigger_seconds=30.0, slide_stride_seconds=10.0, ) self.processor.reset_streaming() self.audio_chunk_idx = 0 def reset_session(self, reset_token2wav_cache=True): self.llm_past_key_values = None self.audio_past_key_values = None self.tts_last_turn_tokens = None self.llm_generated = False # last turn generated by llm or not self.llm_generate_completed = False self.new_user_msg = True self.session_id = None if reset_token2wav_cache: self.token2wav_cache = None # for sliding window self.streaming_text_preserve = 0 self.streaming_position_offset = 0 self._rope_inv_freq_cache: Dict[Tuple[int, torch.device], torch.Tensor] = {} self._next_round_id = 0 self._pending_round_id = None self._omni_chunk_history: List[Dict[str, Union[str, int]]] = [] self._round_history: List[Dict[str, Union[int, str, torch.Tensor, Optional[int]]]] = [] def init_vision_module(self): if self.config._attn_implementation == "flash_attention_2": self.config.vision_config._attn_implementation = "flash_attention_2" else: self.config.vision_config._attn_implementation = "eager" model = SiglipVisionTransformer(self.config.vision_config) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] setattr(model, "embed_dim", model.embeddings.embed_dim) setattr(model, "patch_size", model.embeddings.patch_size) return model def init_resampler(self, embed_dim, vision_dim): return Resampler( num_queries=self.config.query_num, embed_dim=embed_dim, num_heads=embed_dim // 128, kv_dim=vision_dim, adaptive=True, ) def init_audio_module(self): if self.config._attn_implementation == "eager": self.config.audio_config._attn_implementation = "eager" else: # using flash_attention_2 will cause: RuntimeError: cu_seqlens_q must have shape (batch_size + 1) self.config.audio_config._attn_implementation = "sdpa" return MiniCPMWhisperEncoder(self.config.audio_config) def init_tts_module(self): if self.config._attn_implementation == "flash_attention_2": self.config.tts_config.attn_implementation = "flash_attention_2" else: self.config.tts_config.attn_implementation = "eager" return MiniCPMTTS(config=self.config.tts_config, audio_tokenizer=None) def _ensure_asset_dir(self, asset_subpath: str, model_dir: Optional[str] = None) -> str: """Ensure asset directory exists, downloading from HF if needed.""" model_dir = model_dir or os.path.join(self.config._name_or_path, asset_subpath) if not os.path.exists(model_dir): from huggingface_hub import snapshot_download repo_dir = snapshot_download( repo_id="openbmb/MiniCPM-o-4_5", allow_patterns=[f"{asset_subpath}/**"], ) model_dir = os.path.join(repo_dir, asset_subpath) assert os.path.exists(model_dir), f"Asset directory not found: {model_dir}" return model_dir def init_tts(self, model_dir=None, enable_float16=False, n_timesteps=10, **kwargs): if self.config.tts_config.audio_tokenizer_type != "s3tokenizer_step_audio": logger.warning("audio tokenizer type is set to s3tokenizer_step_audio") self.tts.config.audio_tokenizer_type = "s3tokenizer_step_audio" try: from stepaudio2 import Token2wav except ImportError: raise ImportError("Please install Token2wav via: pip install minicpmo-utils[all]") model_dir = self._ensure_asset_dir("assets/token2wav", model_dir) self.tts.audio_tokenizer = Token2wav(model_dir, float16=enable_float16, n_timesteps=n_timesteps) return self.tts.audio_tokenizer def get_input_embeddings(self): return self.llm.get_input_embeddings() def set_input_embeddings(self, value): self.llm.embed_tokens = value def get_output_embeddings(self): return self.llm.lm_head def set_output_embeddings(self, new_embeddings): self.llm.lm_head = new_embeddings def set_decoder(self, decoder): self.llm = decoder def get_decoder(self): return self.llm @staticmethod def get_sys_prompt(ref_audio=None, mode="default", language="en", ref_audio_max_ms=None): if ref_audio is not None: if isinstance(ref_audio, str): import os import librosa if os.path.isfile(ref_audio): duration = ref_audio_max_ms / 1000.0 if ref_audio_max_ms else None ref_audio, _ = librosa.load(ref_audio, sr=16000, mono=True, duration=duration) else: logger.error(f"Could not find {ref_audio}") ref_audio = None assert isinstance(ref_audio, np.ndarray), "ref_audio error" if mode == "omni": if language == "zh": sys_prompt = "" vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。" vc_prompt_suffix = ( "请用这种声音风格来为用户提供帮助。 请认真、高质量地回复用户的问题。 请用高自然度的方式和用户聊天。" ) else: sys_prompt = "" vc_prompt_prefix = sys_prompt + "Clone the voice in the provided audio prompt." vc_prompt_suffix = "As an assistant, you will speak using this voice style." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} else: sys_msgs = {"role": "system", "content": [sys_prompt]} return sys_msgs elif mode == "audio_assistant": if language == "zh": vc_prompt_prefix = "模仿音频样本的音色并生成新的内容。" vc_prompt_suffix = "你的任务是用这种声音模式来当一个助手。请认真、高质量地回复用户的问题。请用高自然度的方式和用户聊天。你是由面壁智能开发的人工智能助手:面壁小钢炮。" else: vc_prompt_prefix = "Clone the voice in the provided audio prompt." vc_prompt_suffix = "Please assist users while maintaining this voice style. Please answer the user's questions seriously and in a high quality. Please chat with the user in a highly human-like and oral style. You are a helpful assistant developed by ModelBest: MiniCPM-Omni." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} else: logger.warning( "Warning: ref_audio is None, speech generation will be performed based on the default voice." ) sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]} return sys_msgs elif mode == "audio_roleplay": if language == "zh": vc_prompt_prefix = "模仿输入音频中的声音特征。" vc_prompt_suffix = "假装你是上述音频中的人物,与我进行对话。" else: vc_prompt_prefix = "Clone the voice in the provided audio prompt." vc_prompt_suffix = "Try to role-play the character based on the audio prompt above." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio, vc_prompt_suffix]} else: sys_msgs = {"role": "system", "content": ["Use the voice.", vc_prompt_suffix]} return sys_msgs elif mode == "voice_cloning": if language == "zh": vc_prompt_prefix = "模仿输入音频中的声音特征。" else: vc_prompt_prefix = "Clone the voice in the provided audio prompt." if ref_audio is not None: sys_msgs = {"role": "system", "content": [vc_prompt_prefix, ref_audio]} else: raise ValueError("ref_audio con't be None in voice_cloning mode.") return sys_msgs else: sys_prompt = "You are a helpful assistant. You can accept audio and text input and output voice and text." sys_msgs = {"role": "system", "content": [sys_prompt]} return sys_msgs @staticmethod def subsequent_chunk_mask( size: int, chunk_size: int, num_left_chunks: int = -1, device: torch.device = torch.device("cpu"), num_lookhead: int = 0, ) -> torch.Tensor: """Create mask for subsequent steps (size, size) with chunk size, this is for streaming encoder Args: size (int): size of mask chunk_size (int): size of chunk num_left_chunks (int): number of left chunks <0: use full chunk >=0: use num_left_chunks device (torch.device): "cpu" or "cuda" or torch.Tensor.device num_lookhead: Returns: torch.Tensor: mask """ ret = torch.zeros(size, size, device=device, dtype=torch.bool) for i in range(size): if num_left_chunks < 0: start = 0 else: start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) ending = min((i // chunk_size + 1) * chunk_size + num_lookhead, size) ret[i, start:ending] = True return ret def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """Computes the output length of the convolutional layers and the output length of the audio encoder""" input_lengths_after_cnn = (input_lengths - 1) // 2 + 1 input_lengths_after_pooling = ( input_lengths_after_cnn - self.config.audio_pool_step ) // self.config.audio_pool_step + 1 input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32) return input_lengths_after_cnn, input_lengths_after_pooling def get_vision_embedding(self, data): if "vision_hidden_states" not in data: dtype = self.llm.model.embed_tokens.weight.dtype device = self.llm.model.embed_tokens.weight.device tgt_sizes = data["tgt_sizes"] pixel_values_list = data["pixel_values"] vision_hidden_states = [] all_pixel_values = [] img_cnt = [] for pixel_values in pixel_values_list: img_cnt.append(len(pixel_values)) all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values]) # exist image if all_pixel_values: tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)] tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32) max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1]) all_pixel_values = torch.nn.utils.rnn.pad_sequence( all_pixel_values, batch_first=True, padding_value=0.0 ) B, L, _ = all_pixel_values.shape all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device) for i in range(B): patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True vision_batch_size = self.config.vision_batch_size all_pixel_values = all_pixel_values.type(dtype) if B > vision_batch_size: hs = [] for i in range(0, B, vision_batch_size): start_idx = i end_idx = i + vision_batch_size tmp_hs = self.vpm( all_pixel_values[start_idx:end_idx], patch_attention_mask=patch_attn_mask[start_idx:end_idx], tgt_sizes=tgt_sizes[start_idx:end_idx], ).last_hidden_state hs.append(tmp_hs) vision_embedding = torch.cat(hs, dim=0) else: vision_embedding = self.vpm( all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, ).last_hidden_state vision_embedding = self.resampler(vision_embedding, tgt_sizes) start = 0 for pixel_values in pixel_values_list: img_cnt = len(pixel_values) if img_cnt > 0: vision_hidden_states.append(vision_embedding[start : start + img_cnt]) start += img_cnt else: vision_hidden_states.append([]) else: # no image if self.training: dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype) tgt_sizes = torch.Tensor( [ [ (224 // self.config.patch_size), math.ceil(224 / self.config.patch_size), ] ] ).type(torch.int32) dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, tgt_sizes) else: dummy_feature = [] for _ in range(len(pixel_values_list)): vision_hidden_states.append(dummy_feature) else: vision_hidden_states = data["vision_hidden_states"] return vision_hidden_states def get_vllm_embedding(self, data): vision_hidden_states = self.get_vision_embedding(data) if hasattr(self.llm.config, "scale_emb"): vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb else: vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) vision_hidden_states = [ i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states ] bs = len(data["input_ids"]) for i in range(bs): cur_vs_hs = vision_hidden_states[i] if len(cur_vs_hs) > 0: cur_vllm_emb = vllm_embedding[i] cur_image_bound = data["image_bound"][i] if len(cur_image_bound) > 0: image_indices = torch.stack( [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] ).to(vllm_embedding.device) cur_vllm_emb.scatter_( 0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), cur_vs_hs.view(-1, cur_vs_hs.shape[-1]), ) elif self.training: cur_vllm_emb += cur_vs_hs[0].mean() * 0 return vllm_embedding, vision_hidden_states def get_audio_embedding_streaming( self, data, use_extra_context=False, prefix_extra_frames=1, suffix_extra_frames=1, cnn_min_length=None, ): """Extract audio embeddings in a streaming manner using cached key-value pairs. This method processes incoming audio features incrementally and stores/updates `past_key_values` for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended for streaming scenarios. Args: data (dict): - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`. - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch. use_extra_context (bool): If True, assumes input contains extra frames for CNN context. prefix_extra_frames (int): Number of prefix extra frames. suffix_extra_frames (int): Number of suffix extra frames. cnn_min_length (int): Minimum length for CNN input padding. Returns: List[List[torch.Tensor]]: audio embeddings """ wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]] # exist audio if len(wavforms) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape assert batch_size == 1 max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # whisper's past_key_values management (core) if self.audio_past_key_values is not None: cache_length = self.audio_past_key_values[0][0].shape[2] apm_max_len = self.apm.embed_positions.weight.shape[0] if cache_length + max_seq_len >= apm_max_len: logger.warning( f"audio_past_key_values length {cache_length + max_seq_len} exceed {apm_max_len}, reset." ) self.audio_past_key_values = None # build attention mask (bidirectional attention, same as offline mode) batch_size, _, max_mel_seq_len = wavforms.shape current_seq_len = (max_mel_seq_len - 1) // 2 + 1 # if use extra context, need to adjust sequence length if use_extra_context: # calculate actual sequence length after removing redundancy # conv2's stride=2, so the mapping from mel frames to output frames is ceil(x/2) prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0 suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0 current_seq_len = current_seq_len - prefix_to_remove - suffix_to_remove # calculate history length (if there is KV cache) if self.audio_past_key_values is not None: past_len = self.audio_past_key_values[0][0].shape[2] # get history sequence length total_seq_len = past_len + current_seq_len else: past_len = 0 total_seq_len = current_seq_len # create bidirectional attention mask (full attention) audio_attention_mask = torch.zeros( (batch_size, 1, current_seq_len, total_seq_len), dtype=self.apm.conv1.weight.dtype, device=wavforms.device, ) # Step 1: APM processing audio_outputs = self.apm( wavforms, past_key_values=self.audio_past_key_values, use_cache=True, output_hidden_states=True, attention_mask=audio_attention_mask, use_extra_context=use_extra_context, prefix_extra_frames=prefix_extra_frames, suffix_extra_frames=suffix_extra_frames, cnn_min_length=cnn_min_length, ) if hasattr(self, "audio_encoder_layer"): audio_states = audio_outputs.hidden_states[self.audio_encoder_layer] else: audio_states = audio_outputs.last_hidden_state self.audio_past_key_values = audio_outputs.past_key_values # Step 2: Projection audio_embeds = self.audio_projection_layer(audio_states) # Step 3: Pooling audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) num_audio_tokens = feature_lens_after_pooling final_audio_embeds = [] idx = 0 for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds else: return final_audio_embeds else: return [] def get_audio_embedding(self, data, chunk_length=-1, dummy=True): wavforms = data.get("audio_features", []) # (bs, 80, frames) or [], multi audios need filled in advance audio_feature_lens_raw = data.get("audio_feature_lens", []) # list, [[x1, x2], [y1], [z1]] if len(wavforms) > 0: audio_feature_lens = torch.hstack(audio_feature_lens_raw) batch_size, _, max_mel_seq_len = wavforms.shape max_seq_len = (max_mel_seq_len - 1) // 2 + 1 # Create a sequence tensor of shape (batch_size, max_seq_len) seq_range = ( torch.arange( 0, max_seq_len, dtype=audio_feature_lens.dtype, device=audio_feature_lens.device, ) .unsqueeze(0) .expand(batch_size, max_seq_len) ) lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len) # Create mask padding_mask = seq_range >= lengths_expand # 1 for padded values audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( batch_size, 1, max_seq_len, max_seq_len ) audio_attention_mask = audio_attention_mask_.to( dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device ) if chunk_length > 0: chunk_num_frame = int(chunk_length * 50) chunk_mask = self.subsequent_chunk_mask( size=max_seq_len, chunk_size=chunk_num_frame, num_left_chunks=-1, device=audio_attention_mask_.device, ) audio_attention_mask_ = torch.logical_or(audio_attention_mask_, torch.logical_not(chunk_mask)) audio_attention_mask[audio_attention_mask_] = float("-inf") audio_states = self.apm( wavforms, output_hidden_states=True, attention_mask=audio_attention_mask ).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) _, feature_lens_after_pooling = self._get_feat_extract_output_lengths(audio_feature_lens) num_audio_tokens = feature_lens_after_pooling final_audio_embeds = [] idx = 0 for i in range(len(audio_feature_lens_raw)): target_audio_embeds = [] for _ in range(len(audio_feature_lens_raw[i])): target_audio_embeds.append(audio_embeds[idx, : num_audio_tokens[idx], :]) idx += 1 final_audio_embeds.append(target_audio_embeds) return final_audio_embeds elif self.training and dummy: dtype = self.apm.embed_positions.weight.dtype device = self.apm.embed_positions.weight.device dummy_wavs = torch.zeros((1, 80, 100), device=device, dtype=dtype) audio_states = self.apm(dummy_wavs, output_hidden_states=True).hidden_states[self.audio_encoder_layer] audio_embeds = self.audio_projection_layer(audio_states) audio_embeds = audio_embeds.transpose(1, 2) audio_embeds = self.audio_avg_pooler(audio_embeds) audio_embeds = audio_embeds.transpose(1, 2) return [audio_embeds] else: return [] def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False): """ Args: data: input_embeddings: chunk_length: whisper use full attention or chunk attention stream_input: use streaming audio embedding or not Returns: final embeddings with audio feature """ if stream_input: audio_embeddings = self.get_audio_embedding_streaming(data) else: audio_embeddings = self.get_audio_embedding(data, chunk_length) bs = len(input_embeddings) if len(data.get("audio_features", [])) > 0: assert len(audio_embeddings) == len(input_embeddings) if len(audio_embeddings) > 0: audio_bounds = data["audio_bounds"] if self.config.stream_input: assert bs == 1, "audio stream_input mode only support batch size 1" for i in range(bs): audio_embs = torch.cat(audio_embeddings[i], dim=0).to( device=input_embeddings.device, dtype=input_embeddings.dtype ) audio_start_pos = 0 for bound in audio_bounds[i]: audio_len = bound[1] - bound[0] input_embeddings[i, bound[0] : bound[1]] = audio_embs[ audio_start_pos : audio_start_pos + audio_len, : ] audio_start_pos += audio_len else: for i in range(bs): audio_embs = audio_embeddings[i] bounds = audio_bounds[i] for embs, bound in zip(audio_embs, bounds): audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to( input_embeddings.device ) if embs.shape[0] != len(audio_indices): raise ValueError( f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} " f"to input indices of length {len(audio_indices)}" ) input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype) elif self.training: for i in range(bs): # dummy audio_embedings input_embeddings += audio_embeddings[0].mean() * 0 return input_embeddings def forward(self, data, **kwargs): vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data) vllm_embedding = self.get_omni_embedding( data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length, ) position_ids = data["position_ids"] if position_ids.dtype != torch.int64: position_ids = position_ids.long() return self.llm( input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs, ) def _decode(self, inputs_embeds, tokenizer, attention_mask, **kwargs): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] outputs = self.llm.generate( inputs_embeds=inputs_embeds, pad_token_id=0, eos_token_id=terminators, attention_mask=attention_mask, output_hidden_states=True, return_dict_in_generate=True, **kwargs, ) return outputs def _decode_stream(self, inputs_embeds, tokenizer, **kwargs): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] streamer = TextIteratorStreamer(tokenizer=tokenizer) generation_config = { "inputs_embeds": inputs_embeds, "pad_token_id": 0, "eos_token_id": terminators, "streamer": streamer, } generation_config.update(kwargs) thread = Thread(target=self.llm.generate, kwargs=generation_config) thread.start() return streamer def _decode_text(self, result_ids, tokenizer): terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators] result_text = [] for result in result_ids: result = result[result != 0] if result[0] == tokenizer.bos_id: result = result[1:] if result[-1] in terminators: result = result[:-1] result_text.append(tokenizer.decode(result)) return result_text @torch.inference_mode() def generate( self, input_ids=None, pixel_values=None, tgt_sizes=None, audio_features=None, audio_feature_lens=None, image_bound=None, audio_bounds=None, spk_bounds=None, attention_mask=None, tokenizer=None, vision_hidden_states=None, stream=False, **kwargs, ): assert input_ids is not None assert len(input_ids) == len(pixel_values) model_inputs = { "input_ids": input_ids, "audio_features": audio_features, "audio_feature_lens": audio_feature_lens, "image_bound": image_bound, "audio_bounds": audio_bounds, "spk_bounds": spk_bounds, } if vision_hidden_states is None: model_inputs["pixel_values"] = pixel_values model_inputs["tgt_sizes"] = tgt_sizes else: model_inputs["vision_hidden_states"] = vision_hidden_states with torch.inference_mode(): model_inputs["inputs_embeds"], vision_hidden_states = self.get_vllm_embedding(model_inputs) model_inputs["inputs_embeds"] = self.get_omni_embedding( model_inputs, input_embeddings=model_inputs["inputs_embeds"], chunk_length=self.config.audio_chunk_length, ) if stream: result = self._decode_stream(model_inputs["inputs_embeds"], tokenizer, **kwargs) outputs = {} # if stream return TextIteratorStreamer and output is empty else: outputs = self._decode(model_inputs["inputs_embeds"], tokenizer, attention_mask, **kwargs) result = self._decode_text(outputs.sequences, tokenizer) return result, outputs def _build_streaming_mask(self, tts_tokens_len): tts_sequence_full_length = 1 + self.tts.streaming_text_reserved_len + 1 streaming_attention_mask = torch.zeros(tts_sequence_full_length, dtype=torch.int8) streaming_attention_mask[0 : 1 + 1 + tts_tokens_len + 1] = 1 streaming_attention_mask[-1] = 1 return streaming_attention_mask def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048): spk_embeds = self._get_last_spk_embeds(inputs, outputs) text = text.split("<|tts_bos|>")[-1] gen_text = text.split("<|tts_eos|>")[0] tts_text, tts_token_lens = self.prepare_tts_text(gen_text) tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False) tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to(self.device, dtype=torch.long) streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device) logits_warpers, logits_processors = gen_logits( num_code=626, top_p=self.tts.top_p, top_k=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty, ) condition_length = 1 + self.tts.streaming_text_reserved_len + 1 dtype = self.tts.emb_text.weight.dtype emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device) past_key_values = [ ( torch.zeros( 1, self.tts.config.num_attention_heads, condition_length - 1, self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, device=self.tts.device, ), torch.zeros( 1, self.tts.config.num_attention_heads, condition_length - 1, self.tts.config.hidden_size // self.tts.config.num_attention_heads, dtype=emb.dtype, device=self.tts.device, ), ) for _ in range(self.tts.config.num_hidden_layers) ] audio_input_ids = torch.zeros( 1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device, ) eos_lab = False for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)): if chunk_idx == 0: begin = chunk_idx * self.tts.streaming_text_chunk_size + 0 end = (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1 else: begin = chunk_idx * self.tts.streaming_text_chunk_size + 1 end = min( (chunk_idx + 1) * self.tts.streaming_text_chunk_size + 1, condition_length - 1, ) if end - begin > 0: text_input_ids = tts_input_ids[:, begin:end] position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0) if begin == 0: past_key_values = self.tts.prefill_text( input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values, lm_spk_emb_last_hidden_states=spk_embeds, ) else: past_key_values = self.tts.prefill_text( input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values, ) outputs = self.tts.generate( input_ids=audio_input_ids, past_key_values=past_key_values, streaming_tts_text_mask=streaming_tts_text_mask, max_new_token=output_chunk_size, force_no_stop=self.force_no_stop, temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), logits_warpers=logits_warpers, logits_processors=logits_processors, ) audio_input_ids = outputs.audio_input_ids past_key_values = outputs.past_key_values if outputs.finished: eos_lab = True break if not eos_lab: while True: outputs = self.tts.generate( input_ids=audio_input_ids, past_key_values=past_key_values, streaming_tts_text_mask=streaming_tts_text_mask, max_new_token=output_chunk_size, force_no_stop=self.force_no_stop, temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device), eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device), logits_warpers=logits_warpers, logits_processors=logits_processors, ) audio_input_ids = outputs.audio_input_ids past_key_values = outputs.past_key_values if outputs.finished: break if outputs.new_ids.shape[1] > tts_max_new_tokens: break @staticmethod def prepare_generation_config(do_sample, max_new_tokens=50, min_new_tokens=0, **kwargs): num_beams = kwargs.get("num_beams", 3) generation_config = { "num_beams": num_beams, "top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.02, } if do_sample: generation_config.update( { "top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.02, } ) elif num_beams > 1: generation_config.update({"num_beams": num_beams, "repetition_penalty": 1.2, "do_sample": False}) else: generation_config.update({"do_sample": False, "repetition_penalty": 1.02}) generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) generation_config["min_new_tokens"] = min_new_tokens generation_config["max_new_tokens"] = max_new_tokens return generation_config def prepare_processor(self, processor=None, tokenizer=None): if processor is not None: self.processor = processor if not hasattr(self, "processor") or self.processor is None: self.processor = MiniCPMOProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True) if tokenizer is not None: self.processor.tokenizer = tokenizer @torch.inference_mode() def chat( self, image=None, msgs=None, vision_hidden_states=None, max_new_tokens=4096, min_new_tokens=0, do_sample=True, max_inp_length=8192, max_slice_nums=None, use_image_id=None, enable_thinking=False, use_tts_template=False, generate_audio=False, output_audio_path=None, output_tts_inputs_embeds_path=None, omni_mode=False, teacher_forcing=False, return_prompt=False, tts_proj_layer=-1, tts_sampling_params: TTSSamplingParams = TTSSamplingParams(), merge_audio_from_same_content=True, stream=False, stream_input=False, tokenizer=None, processor=None, **kwargs, ): from PIL import Image batched = isinstance(msgs[0], list) msgs_list = msgs images_list = image if not batched: images_list, msgs_list = [images_list], [msgs_list] else: assert images_list is None, "Please integrate image to msgs when using batch inference." images_list = [None] * len(msgs_list) assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same." self.prepare_processor(processor=processor, tokenizer=tokenizer) prompts_lists = [] input_images_list = [] input_audios_list = [] audio_parts_list = [] for image, msgs in zip(images_list, msgs_list): if isinstance(msgs, str): msgs = json.loads(msgs) copy_msgs = deepcopy(msgs) assert len(msgs) > 0, "msgs is empty" assert do_sample or not stream, "if use stream mode, make sure do_sample=True" if image is not None and isinstance(copy_msgs[0]["content"], str): copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] images = [] audios = [] audio_parts = [] for i, msg in enumerate(copy_msgs): role = msg["role"] content = msg["content"] assert role in ["system", "user", "assistant"] if i == 0: assert role in ["user", "system"], "The role of first msg should be user" # Normalize structured content (OpenAI format) to native format content = normalize_content(content) cur_msgs = [] for c in content: if isinstance(c, Image.Image): images.append(c) cur_msgs.append("./") elif isinstance(c, np.ndarray): # audio audios.append(c) audio_parts.append(i) cur_msgs.append("") use_tts_template = True elif isinstance(c, str): cur_msgs.append(c) if omni_mode or stream_input: msg["content"] = "".join(cur_msgs) else: msg["content"] = "\n".join(cur_msgs) prompts_lists.append( self.processor.tokenizer.apply_chat_template( copy_msgs, tokenize=False, add_generation_prompt=False if teacher_forcing else True, use_tts_template=use_tts_template, enable_thinking=enable_thinking, ) ) input_images_list.append(images) input_audios_list.append(audios) audio_parts_list.append(audio_parts) if not merge_audio_from_same_content: audio_parts_list = None inputs = self.processor( prompts_lists, input_images_list, input_audios_list, audio_parts_list, max_slice_nums=max_slice_nums, use_image_id=use_image_id, stream_input=stream_input, return_tensors="pt", max_length=max_inp_length, ).to(self.device) generation_config = self.prepare_generation_config( do_sample=do_sample, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, **kwargs ) generation_config.pop("max_new_tokens", None) inputs.pop("image_sizes") # teacher_forcing = True => generate audio with given text with torch.inference_mode(): res, outputs = self.generate( **inputs, tokenizer=self.processor.tokenizer, max_new_tokens=1 if teacher_forcing else max_new_tokens, vision_hidden_states=vision_hidden_states, stream=stream, **generation_config, ) # spk bound and tts bound tts_bos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_bos|>") tts_eos_token = self.processor.tokenizer.convert_tokens_to_ids("<|tts_eos|>") # Combine input_ids and generated sequences to get complete sequence input_ids = inputs["input_ids"][0] generated_ids = outputs.sequences[0] # Combine by concatenating input_ids with the new tokens from generated sequence full_sequence = torch.cat([input_ids, generated_ids]) # Update the sequences in outputs full_sequences = full_sequence.unsqueeze(0) outputs["full_sequences"] = full_sequences tts_bos_indices = [] tts_eos_indices = [] for i, x in enumerate(full_sequences[0]): if x == tts_bos_token: # tts_bos + 1 is the position of the first tts, so that it is convenient to slice hidden states for tts tts_bos_indices.append(i + 1) elif x == tts_eos_token: if teacher_forcing and i == len(full_sequences[0]) - 1: continue tts_eos_indices.append(i) tts_bos_idx = tts_bos_indices[-1] if tts_bos_indices else -1 # Use None instead of -1 when no EOS token found, so that slice [start:None] # means "to the end" rather than [start:-1] which excludes the last element tts_eos_idx = tts_eos_indices[-1] if tts_eos_indices else None tts_bound = (tts_bos_idx, tts_eos_idx) answer = res[0] if answer is not None: answer = answer.split("<|tts_eos|>")[0] if use_tts_template and generate_audio and output_audio_path: import soundfile as sf try: generated_waveform = self._generate_speech_non_streaming( outputs=outputs, tts_bound=tts_bound, tts_proj_layer=tts_proj_layer, audio_prompt=( input_audios_list[0][0] if len(input_audios_list) > 0 and len(input_audios_list[0]) > 0 else None ), output_tts_inputs_embeds_path=output_tts_inputs_embeds_path, tts_sampling_params=tts_sampling_params, ) if isinstance(generated_waveform, torch.Tensor): sf.write(output_audio_path, generated_waveform.cpu().numpy(), samplerate=24000) elif isinstance(generated_waveform, np.ndarray): sf.write(output_audio_path, generated_waveform, samplerate=24000) logger.debug(f"audio saved to {output_audio_path}") except: import traceback traceback.print_exc() if return_prompt: return answer, prompts_lists[0] else: return answer @torch.inference_mode() def _generate_speech_non_streaming( self, outputs, tts_bound, tts_proj_layer, audio_prompt, output_tts_inputs_embeds_path=None, tts_sampling_params: TTSSamplingParams = TTSSamplingParams(), ): last_hidden_states = [hs[tts_proj_layer] for hs in outputs.hidden_states] last_hidden_states = torch.vstack([i[0] for i in last_hidden_states]) spk_embeds = ( torch.ones([0, self.tts.config.hidden_size]).to(last_hidden_states.device).to(last_hidden_states.dtype) ) if self.tts.condition_type == "hidden_text_merge": llm_tokens = outputs["full_sequences"][0][tts_bound[0] : tts_bound[1]] llm_tokens = torch.tensor(llm_tokens, device=self.tts.emb_text.weight.device, dtype=torch.long) llm_embeds = self.tts.emb_text(llm_tokens) # make sure emb_text is compatible with llm vocab size hidden_embeds = last_hidden_states[tts_bound[0] : tts_bound[1]] hidden_embeds = self.tts.projector_semantic(hidden_embeds) if self.tts.config.normalize_projected_hidden: hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1) tts_embeds = llm_embeds + hidden_embeds else: raise NotImplementedError audio_bos = [self.tts.audio_bos_token_id] audio_bos = torch.tensor(audio_bos, device=self.tts.emb_text.weight.device, dtype=torch.long) audio_bos_embeds = self.tts.emb_text(audio_bos) text_eos_embed = self.tts.emb_text( torch.tensor( [self.tts.config.text_eos_token_id], device=self.tts.emb_text.weight.device, dtype=torch.long, ) ) inputs_embeds = torch.cat([spk_embeds, tts_embeds, text_eos_embed, audio_bos_embeds], dim=0).unsqueeze(0) # save inputs_embeds to file if output_tts_inputs_embeds_path: torch.save(inputs_embeds, output_tts_inputs_embeds_path) outputs = self.tts.generate( inputs_embeds=inputs_embeds, sampling_params=tts_sampling_params, eos_token=torch.tensor( [self.tts.config.num_audio_tokens - 1], dtype=torch.long, device=self.tts.device, ), ) import io import soundfile as sf generated_tokens = outputs.new_ids.squeeze(-1) reference_audio = audio_prompt prompt_wav_path = None if reference_audio is not None: logger.debug("use reference audio in data to generate waveform") with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav: prompt_wav_path = tmp_wav.name sf.write(prompt_wav_path, reference_audio, 16000) wav_bytes = self.tts.audio_tokenizer( generated_tokens.squeeze(0).tolist(), prompt_wav_path, ) # convert wav bytes back to tensor for caller compatibility waveform, sr = sf.read(io.BytesIO(wav_bytes)) return torch.tensor(waveform, dtype=torch.float32) @torch.inference_mode() def init_token2wav_cache(self, prompt_speech_16k): import soundfile as sf if hasattr(self.tts.audio_tokenizer, "set_stream_cache"): self.tts.audio_tokenizer.cache = None with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav: prompt_wav_path = tmp_wav.name sf.write(prompt_wav_path, prompt_speech_16k, 16000) flow_cache_base, hift_cache_base = self.tts.audio_tokenizer.set_stream_cache(prompt_wav_path) self.token2wav_cache = { "flow_cache_base": torch_clone_recursive(flow_cache_base), "hift_cache_base": torch_clone_recursive(hift_cache_base), } else: model_input = self.tts.audio_tokenizer.frontend.frontend_token2wav( speech_tokens=torch.zeros(1, 1, dtype=torch.long, device=self.tts.device), speech_16k=None, prompt_speech_16k=prompt_speech_16k, resample_rate=self.tts.audio_tokenizer.sample_rate, prompt_speech=None, ) prompt_token = model_input["flow_prompt_speech_token"] prompt_feat = model_input["prompt_speech_feat"] embedding = model_input["flow_embedding"] if self.tts.audio_tokenizer.fp16: prompt_feat = prompt_feat.to(torch.half) embedding = embedding.to(torch.half) prepared_cache = self.tts.audio_tokenizer.model.prepare_cache_from_prompt( prompt_token=prompt_token, prompt_feat=prompt_feat, embedding=embedding, n_timesteps=self.tts.config.s3_stream_n_timesteps, code_chunk_size=self.tts.config.s3_stream_chunk_size, chunk_prelook_size=self.tts.config.s3_stream_prelook_size, use_attn_idx=False, ) self.token2wav_cache = prepared_cache # for sliding window def _ensure_dynamic_cache(self): cache = self.llm_past_key_values if cache is None: return None cache = as_dynamic_cache(cache) if isinstance(cache, DynamicCache): self.llm_past_key_values = cache return cache return None def _get_kv_cache_length(self, cache=None): cache = cache if cache is not None else self.llm_past_key_values return get_kv_cache_length(cache) # todo: not-used del? def _rebuild_cache_from_history(self): preserved_ids: List[torch.Tensor] = [] for entry in self._omni_chunk_history: ids = entry.get("input_ids") if ids is None or not isinstance(ids, torch.Tensor) or ids.numel() == 0: continue preserved_ids.append(ids.to(self.device)) if not preserved_ids: self.llm_past_key_values = None self.streaming_position_offset = 0 self._rope_inv_freq_cache.clear() return concat_ids = torch.cat(preserved_ids, dim=1) attention_mask = torch.ones((1, concat_ids.shape[1]), dtype=torch.bool, device=self.device) outputs = self.llm( input_ids=concat_ids, attention_mask=attention_mask, use_cache=True, return_dict=True, ) self.llm_past_key_values = outputs.past_key_values self.streaming_position_offset = 0 self._rope_inv_freq_cache.clear() def _get_rope_theta(self) -> float: return float(getattr(self.llm.config, "rope_theta", 10000.0)) def _realign_rotary_suffix( self, suffix_keys: torch.Tensor, old_positions: torch.Tensor, new_positions: torch.Tensor, ) -> torch.Tensor: return realign_rotary_suffix( suffix_keys, old_positions, new_positions, rope_theta=self._get_rope_theta(), inv_freq_cache=self._rope_inv_freq_cache, ) def _encode_text(self, tokenizer, text) -> Optional[torch.Tensor]: if tokenizer is None or not text: return None ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"] return ids.to(self.device) @staticmethod def _safe_decode(tokenizer, input_ids): if tokenizer is None or input_ids is None: return None if isinstance(input_ids, torch.Tensor): ids = input_ids.cpu().tolist() if ids and isinstance(ids[0], list): ids = ids[0] else: ids = input_ids try: return tokenizer.decode(ids, skip_special_tokens=False) except Exception: return None def _finalize_round( self, round_id: Optional[int], cache_before: int, assistant_input_ids: Optional[torch.Tensor] = None ): if round_id is None: self._pending_round_id = None return cache_after = self._get_kv_cache_length() if assistant_input_ids is not None: assistant_len = assistant_input_ids.shape[1] else: assistant_len = max(cache_after - cache_before, 0) if assistant_len > 0: self._register_chunk( assistant_len, "assistant", round_id=round_id, input_ids=assistant_input_ids, tokenizer=self.processor.tokenizer if hasattr(self, "processor") else None, ) self._pending_round_id = None self._next_round_id += 1 def _register_chunk( self, seq_len: int, chunk_type: str, *, round_id: int, input_ids=None, tokenizer=None, ) -> None: if seq_len <= 0: return entry = {"length": int(seq_len), "type": chunk_type, "round": round_id} if input_ids is not None: entry["input_ids"] = input_ids.clone().detach() entry["decoded"] = self._safe_decode(tokenizer, entry["input_ids"]) else: entry["input_ids"] = None entry["decoded"] = None self._omni_chunk_history.append(entry) if chunk_type == "system": self.streaming_text_preserve = max(self.streaming_text_preserve, entry["length"]) def _drop_tokens_from_cache(self, length: int, cache: DynamicCache) -> bool: """Drop tokens from cache using the utility function.""" _, new_offset, success = drop_tokens_from_cache( cache=cache, length=length, preserve=self.streaming_text_preserve, position_offset=self.streaming_position_offset, rope_theta=self._get_rope_theta(), inv_freq_cache=self._rope_inv_freq_cache, ) if success: self.streaming_position_offset = new_offset return success def _drop_next_round(self, cache: DynamicCache) -> bool: seen_rounds = set() for entry in self._omni_chunk_history: round_id = entry.get("round") if round_id is None or round_id in seen_rounds: continue seen_rounds.add(round_id) round_entries = [e for e in self._omni_chunk_history if e.get("round") == round_id] if any(e.get("type") == "system" for e in round_entries): continue if self._drop_round(round_id, cache): return True return False def _drop_round(self, round_id: int, cache: DynamicCache) -> bool: entries = [e for e in self._omni_chunk_history if e.get("round") == round_id] if not entries: return False total_len = sum(e["length"] for e in entries) if total_len <= 0: for e in entries: self._omni_chunk_history.remove(e) return False if not self._drop_tokens_from_cache(total_len, cache): return False for e in entries: self._omni_chunk_history.remove(e) return True def _enforce_text_window(self) -> None: if not self.streaming_window_enabled: return cache = self._ensure_dynamic_cache() if cache is None: return high_limit = max(0, int(self.streaming_window_config.text_window_high_tokens)) low_limit = max(0, int(self.streaming_window_config.text_window_low_tokens)) if high_limit <= 0: return target = max(0, low_limit) total_len = self._get_kv_cache_length(cache) if total_len <= high_limit: return dropped_any = False while total_len > target: if not self._drop_next_round(cache): break dropped_any = True total_len = self._get_kv_cache_length(cache) # snapshot, vad def save_speculative_snapshot(self) -> SpeculativeSnapshot: """Internal method: save speculative snapshot. Called at the start of streaming_generate, saves to self._speculative_snapshot. Save strategy: - LLM KV Cache: only record length (restore by truncation, zero extra VRAM) - Audio KV Cache: deep clone (as generate sets it to None) - Mel processor: full state snapshot (including buffer) """ # get LLM cache information llm_cache_length = self._get_kv_cache_length() llm_cache_checksum = None if self.llm_past_key_values is not None and hasattr(self.llm_past_key_values, "key_cache"): if len(self.llm_past_key_values.key_cache) > 0: llm_cache_checksum = self.llm_past_key_values.key_cache[0].sum().item() # get audio cache length and clone audio_past_key_values audio_cache_length = 0 audio_cache_checksum = None audio_past_key_values_clone = None if self.audio_past_key_values is not None: # handle DynamicCache format (Whisper encoder may return this format) if isinstance(self.audio_past_key_values, DynamicCache): if hasattr(self.audio_past_key_values, "key_cache") and len(self.audio_past_key_values.key_cache) > 0: audio_cache_length = self.audio_past_key_values.key_cache[0].shape[2] audio_cache_checksum = self.audio_past_key_values.key_cache[0].sum().item() # deep clone DynamicCache cloned_cache = DynamicCache() for k, v in zip(self.audio_past_key_values.key_cache, self.audio_past_key_values.value_cache): cloned_cache.update(k.clone(), v.clone(), layer_idx=len(cloned_cache.key_cache)) audio_past_key_values_clone = cloned_cache # handle EncoderDecoderCache format elif isinstance(self.audio_past_key_values, EncoderDecoderCache): self_attn_cache = self.audio_past_key_values.self_attention_cache if hasattr(self_attn_cache, "key_cache") and len(self_attn_cache.key_cache) > 0: audio_cache_length = self_attn_cache.key_cache[0].shape[2] audio_cache_checksum = self_attn_cache.key_cache[0].sum().item() # deep clone EncoderDecoderCache cloned_self_attn = DynamicCache() if hasattr(self_attn_cache, "key_cache"): for k, v in zip(self_attn_cache.key_cache, self_attn_cache.value_cache): cloned_self_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_self_attn.key_cache)) cross_attn_cache = self.audio_past_key_values.cross_attention_cache cloned_cross_attn = DynamicCache() if hasattr(cross_attn_cache, "key_cache"): for k, v in zip(cross_attn_cache.key_cache, cross_attn_cache.value_cache): cloned_cross_attn.update(k.clone(), v.clone(), layer_idx=len(cloned_cross_attn.key_cache)) audio_past_key_values_clone = EncoderDecoderCache(cloned_self_attn, cloned_cross_attn) # handle tuple format (compatible with old format) elif isinstance(self.audio_past_key_values, tuple) and len(self.audio_past_key_values) > 0: audio_cache_length = self.audio_past_key_values[0][0].shape[2] audio_cache_checksum = self.audio_past_key_values[0][0].sum().item() # deep clone audio_past_key_values (tuple of tuples of tensors) audio_past_key_values_clone = tuple( tuple(t.clone() for t in layer_cache) for layer_cache in self.audio_past_key_values ) # get mel processor snapshot mel_processor_snapshot = None mel_buffer_checksum = None if hasattr(self, "processor") and self.processor is not None: mel_processor_snapshot = self.processor.get_streaming_snapshot() if mel_processor_snapshot: buf = mel_processor_snapshot.get("buffer") if buf is not None and len(buf) > 0: mel_buffer_checksum = float(buf.sum()) # save RNG state (important: for deterministic dithering and other random operations after restoration) rng_state_cpu = torch.get_rng_state() rng_state_cuda = None if torch.cuda.is_available() and self.device.type == "cuda": rng_state_cuda = torch.cuda.get_rng_state(self.device) # create snapshot snapshot = SpeculativeSnapshot( llm_cache_length=llm_cache_length, audio_cache_length=audio_cache_length, new_user_msg=self.new_user_msg, llm_generated=self.llm_generated, llm_generate_completed=self.llm_generate_completed, next_round_id=self._next_round_id, pending_round_id=self._pending_round_id, omni_chunk_history_length=len(self._omni_chunk_history), tts_last_turn_tokens=self.tts_last_turn_tokens.clone() if self.tts_last_turn_tokens is not None else None, audio_chunk_idx=self.audio_chunk_idx, mel_processor_snapshot=mel_processor_snapshot, audio_past_key_values=audio_past_key_values_clone, timestamp=time.time(), # debug fields llm_cache_checksum=llm_cache_checksum, audio_cache_checksum=audio_cache_checksum, mel_buffer_checksum=mel_buffer_checksum, # RNG state rng_state_cpu=rng_state_cpu, rng_state_cuda=rng_state_cuda, ) return snapshot def restore_speculative_snapshot(self, snapshot=None) -> bool: """Restore speculative snapshot - called when VAD speculation fails. Restores model state to before streaming_generate was called, allowing continued streaming_prefill for newly arrived audio. Notes: - Snapshot is saved when streaming_generate is called with enable_speculative_snapshot=True - This method uses the most recent snapshot for restoration - Snapshot is cleared after restore, cannot be called repeatedly Returns: bool: Whether restoration was successful """ snapshot = snapshot or getattr(self, "_speculative_snapshot", None) if snapshot is None: return False try: current_cache_length = self._get_kv_cache_length() current_history_length = len(self._omni_chunk_history) # 1. truncate LLM KV Cache if current_cache_length > snapshot.llm_cache_length: self._truncate_llm_cache(snapshot.llm_cache_length) # 2. restore Audio KV Cache (important: restore from cloned copy) # because streaming_generate will set audio_past_key_values to None self.audio_past_key_values = snapshot.audio_past_key_values # 3. restore session state self.new_user_msg = snapshot.new_user_msg self.llm_generated = snapshot.llm_generated self.llm_generate_completed = snapshot.llm_generate_completed # 4. restore Round management self._next_round_id = snapshot.next_round_id self._pending_round_id = snapshot.pending_round_id # 5. truncate chunk history if current_history_length > snapshot.omni_chunk_history_length: self._omni_chunk_history = self._omni_chunk_history[: snapshot.omni_chunk_history_length] # 6. restore TTS state self.tts_last_turn_tokens = snapshot.tts_last_turn_tokens # 7. restore streaming processor state self.audio_chunk_idx = snapshot.audio_chunk_idx # 8. restore mel processor state (important: otherwise subsequent prefill will fail due to frame number mismatch) if ( snapshot.mel_processor_snapshot is not None and hasattr(self, "processor") and self.processor is not None ): self.processor.restore_streaming_snapshot(snapshot.mel_processor_snapshot) # 9. restore RNG state (important: ensure determinism of dithering and other random operations after restoration) if snapshot.rng_state_cpu is not None: torch.set_rng_state(snapshot.rng_state_cpu) if snapshot.rng_state_cuda is not None and torch.cuda.is_available(): torch.cuda.set_rng_state(snapshot.rng_state_cuda, self.device) # 10. clean up temporary states generated during generation if hasattr(self, "_streaming_generated_token_ids"): del self._streaming_generated_token_ids if hasattr(self, "_last_streaming_text"): del self._last_streaming_text # 11. clear snapshot (can only be restored once) self._speculative_snapshot = None return True except Exception as e: import traceback logger.error(traceback.format_exc()) return False def has_speculative_snapshot(self) -> bool: return getattr(self, "_speculative_snapshot", None) is not None def clear_speculative_snapshot(self) -> None: if hasattr(self, "_speculative_snapshot"): self._speculative_snapshot = None def _truncate_llm_cache(self, target_length: int) -> None: if self.llm_past_key_values is None: return cache = self._ensure_dynamic_cache() if cache is None: return current_length = self._get_kv_cache_length(cache) if current_length <= target_length: return # truncate each layer of cache for layer_idx in range(len(cache.key_cache)): if cache.key_cache[layer_idx].numel() > 0: cache.key_cache[layer_idx] = cache.key_cache[layer_idx][:, :, :target_length, :].contiguous() cache.value_cache[layer_idx] = cache.value_cache[layer_idx][:, :, :target_length, :].contiguous() # update cache metadata cache.crop(target_length) cache._seen_tokens = target_length @torch.inference_mode() def streaming_prefill( self, session_id, msgs, omni_mode=True, max_slice_nums=None, use_tts_template=True, enable_thinking=False, is_last_chunk=False, # for audio chunk, if is the last chunk, set to True tokenizer=None, processor=None, **kwargs, ): from PIL import Image assert session_id is not None, "session_id cannot be None" self.is_first = self.session_id is None or session_id != self.session_id self.prepare_processor(processor=processor, tokenizer=tokenizer) images = [] audios = [] assert len(msgs) == 1 copy_msgs = deepcopy(msgs) msg = copy_msgs[0] assert msg["role"] in ["system", "user", "assistant"] is_not_system_prefill = msg["role"] != "system" content = msg["content"] cur_msgs = [] for j, c in enumerate(content): if isinstance(c, Image.Image): images.append(c) cur_msgs.append("./") elif isinstance(c, np.ndarray): audios.append(c) cur_msgs.append("") elif isinstance(c, str): cur_msgs.append(c) else: logger.error(f"Invalid content type: {c}, ignore it.") cur_contents = "".join(cur_msgs) if omni_mode else "\n".join(cur_msgs) if msg["role"] in ["system", "assistant"]: self.new_user_msg = True self.audio_past_key_values = None if self.is_first: self.reset_session(reset_token2wav_cache=False) self.session_id = session_id self.init_streaming_processor() if msg["role"] == "user": # no system prefill, the first segment of the first user turn # do not use apply_chat_template, manually build prompt to avoid automatic addition of <|im_end|> prompt = "<|im_start|>user\n" + cur_contents self.new_user_msg = False # mark subsequent segments do not need to add user prefix anymore else: # system or assistant prefill, use apply_chat_template msg["content"] = cur_contents prompt = self.processor.tokenizer.apply_chat_template( copy_msgs, tokenize=False, add_generation_prompt=False, use_tts_template=use_tts_template, enable_thinking=enable_thinking, ) add_special_tokens = True # add bos else: # non-first prefill if self.new_user_msg and msg["role"] == "user": # the first segment of the new user turn if self.llm_generated: if self.llm_generate_completed: prompt = "<|im_end|>\n<|im_start|>user\n" + cur_contents else: prompt = "<|tts_eos|><|im_end|>\n<|im_start|>user\n" + cur_contents else: prompt = "<|im_start|>user\n" + cur_contents self.new_user_msg = False else: # subsequent segments of the same turn, directly use content prompt = cur_contents add_special_tokens = False # when first user audio prefill, ensure audio length satisfies FIRST_CHUNK_MS requirements if is_not_system_prefill and len(audios) > 0 and self.audio_chunk_idx == 0: assert len(audios) == 1, f"streaming mode only supports single audio, currently {len(audios)}" first_chunk_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000) if len(audios[0]) < first_chunk_samples: pad_len = first_chunk_samples - len(audios[0]) audios[0] = np.concatenate([np.zeros(pad_len, dtype=audios[0].dtype), audios[0]]) model_inputs = self.processor( [prompt], [images], [audios], max_slice_nums=1 if max_slice_nums is None else max_slice_nums, use_image_id=False, chunk_input=True, return_tensors="pt", max_length=None, sampling_rate=16000, add_special_tokens=add_special_tokens, online_streaming=is_not_system_prefill, audio_chunk_idx=self.audio_chunk_idx, is_last_chunk=is_last_chunk, ).to(self.device) if len(audios) > 0 and is_not_system_prefill: self.audio_chunk_idx += 1 # 1. prepare input embeddings model_inputs["inputs_embeds"], _ = self.get_vllm_embedding(model_inputs) # get audio embedding with audio_past_key_values inputs_embeds = self.get_omni_embedding( model_inputs, input_embeddings=model_inputs["inputs_embeds"], stream_input=is_not_system_prefill ) if self.is_first: self.audio_past_key_values = None round_id = self._next_round_id self._pending_round_id = round_id chunk_type = "system" if msg["role"] == "system" else ("user" if msg["role"] == "user" else "assistant") seq_len = inputs_embeds.shape[1] self._enforce_text_window() cache_length = self._get_kv_cache_length() attention_mask = torch.ones((1, cache_length + inputs_embeds.shape[1]), dtype=torch.bool, device=self.device) # 2. do prefill outputs = self.llm( past_key_values=self.llm_past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=None, use_cache=True, return_dict=True, ) self.llm_past_key_values = as_dynamic_cache(outputs["past_key_values"]) self._register_chunk( seq_len, chunk_type, round_id=round_id, input_ids=model_inputs["input_ids"], tokenizer=self.processor.tokenizer, ) self._enforce_text_window() if self.force_rope_reindex: self._force_reindex_all_cache() return prompt @torch.inference_mode() def streaming_generate( self, session_id, bos_input=None, generate_audio=True, audio_token_chunk_size=25, # 25 token/s tts_sampling_params: TTSSamplingParams = TTSSamplingParams(), max_new_tokens=256, enable_thinking=False, use_tts_template=True, do_sample=True, enable_speculative_snapshot=False, tokenizer=None, processor=None, # Teacher forcing (only for the "text → hidden → TTS condition" pipeline in streaming_generate) # When enabled: instead of letting the LLM auto-regressively generate the text to be spoken, # it forces the tokens from teacher_forcing_text to be fed in, using the hidden states # corresponding to these tokens to construct the TTS condition, ensuring the output audio matches the input text. teacher_forcing: bool = False, teacher_forcing_text: str = "", **kwargs, ): # save speculative snapshot (before modifying any state) # for VAD speculative snapshot: if speculative snapshot fails, can call restore_speculative_snapshot() to restore # enable_speculative_snapshot=True when enabled, skip (save some overhead) when disabled if enable_speculative_snapshot: self._speculative_snapshot = self.save_speculative_snapshot() # reset buf self.new_user_msg = True self.llm_generated = True self.llm_generate_completed = False self.audio_past_key_values = None self.prepare_processor(processor=processor, tokenizer=tokenizer) # reset current turn generated token IDs if hasattr(self, "_streaming_generated_token_ids"): del self._streaming_generated_token_ids # reset full generated text if hasattr(self, "_last_streaming_text"): del self._last_streaming_text cache = self._ensure_dynamic_cache() cache_length = self._get_kv_cache_length(cache) host_round_id = self._pending_round_id ## in single-turn streaming, each call to streaming_generate needs to reinitialize the streaming_processor, enter the next turn self.init_streaming_processor() # 1) llm generate token and hidden states per chunk=10, 2) tts generate audio token chunk per chunk=25, 3) yield 1 chunk audio token def audio_chunk_generator( bos_input, tokenizer, generate_audio, tts_sampling_params, max_new_tokens, do_sample, teacher_forcing=False, teacher_forcing_text="", **kwargs, ): generate_chunk_size = 10 if bos_input is None: bos_input = "".join( [ "<|im_end|>\n<|im_start|>assistant\n", "" if enable_thinking else self.think_str.replace("\\n", "\n"), "<|tts_bos|>" if use_tts_template else "", ] ) bos_input_ids = tokenizer.encode(bos_input) bos_input_ids = torch.tensor(bos_input_ids, dtype=torch.long, device=self.device).unsqueeze(0) bos_input_embeds = self.llm.get_input_embeddings()(bos_input_ids) generation_inputs_embeds = bos_input_embeds generated_ids = torch.empty((1, 0), dtype=torch.long, device=self.device) num_chunks_decode = (max_new_tokens + generate_chunk_size - 1) // generate_chunk_size conditions = [] # generate chunk by chunk, each chunk has 10 tokens, each chunk takes last hidden states, and pass tokens to tts llm_streaming_generator = ChunkPrefillChunkGenerate( model=self.llm, tokenizer=tokenizer, terminators=["<|tts_eos|>", "<|im_end|>", ""], ) if generate_audio: logits_warpers, logits_processors = gen_logits( num_code=self.tts.config.num_audio_tokens, repetition_penalty=tts_sampling_params.repetition_penalty, top_p=tts_sampling_params.top_p, top_k=tts_sampling_params.top_k, ) tts_streaming_generator = TTSStreamingGenerator( model=self.tts, temperature=tts_sampling_params.temperature, eos_token=torch.tensor( [self.tts.config.num_audio_tokens - 1], dtype=torch.long, device=self.tts.device, ), chunk_size=audio_token_chunk_size, # s3tokenizer 1s = 25token tts_last_turn_tokens=self.tts_last_turn_tokens, logits_processors=logits_processors, logits_warpers=logits_warpers, ) # Teacher forcing branch # This branch does not rely on ChunkPrefillChunkGenerate's sampling logic, instead: # 1) First prefill bos_input (assistant + tts_bos) into llm_past_key_values # 2) Tokenize teacher_forcing_text into token ids # 3) Feed tokens one by one into the LLM (teacher forcing), obtaining the last_hidden_states for each token # 4) Use (token_ids, hidden_states) to construct tts condition, then feed it to TTSStreamingGenerator if teacher_forcing: # --- 1) prefill bos_input,延续 streaming_prefill 的 KV cache --- bos_outputs = self.llm( inputs_embeds=generation_inputs_embeds, past_key_values=self.llm_past_key_values, use_cache=True, output_hidden_states=True, return_dict=True, ) self.llm_past_key_values = bos_outputs.past_key_values if generate_audio: # Give a length-0 tensor as speaker embedding (no speaker embedding) spk_emb = torch.empty( (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]), dtype=bos_input_embeds.dtype, device=bos_input_embeds.device, ) tts_streaming_generator.spk_emb = spk_emb # --- 2) tokenize teacher_forcing_text --- tf_text = teacher_forcing_text or "" try: forced_input_ids = tokenizer(tf_text, add_special_tokens=False, return_tensors="pt")["input_ids"] except Exception: # Compatible with rare tokenizer return object attributes forced_input_ids = tokenizer(tf_text, add_special_tokens=False, return_tensors="pt").input_ids forced_input_ids = forced_input_ids.to(self.device) total_len = int(forced_input_ids.shape[1]) ptr = 0 # Special case: empty text should also let TTS finish (text_finished=True will automatically concatenate text_eos_embed) if total_len == 0: if not generate_audio: yield forced_input_ids, True return empty_tts_embeds = torch.empty( (1, 0, self.tts.config.hidden_size), dtype=bos_input_embeds.dtype, device=self.device, ) if not hasattr(self, "_streaming_generated_token_ids"): self._streaming_generated_token_ids = [] tts_generator = tts_streaming_generator.generate_with_buffer( condition=empty_tts_embeds, text_finished=True, ) for audio_token_chunk, is_last_audio_chunk in tts_generator: yield audio_token_chunk, is_last_audio_chunk self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens self._last_streaming_text = "" yield None, None return # --- 3) chunk-by-chunk teacher forcing --- while ptr < total_len: end = min(ptr + generate_chunk_size, total_len) chunk_ids = forced_input_ids[:, ptr:end] # [1, chunk_len] chunk_hidden_list = [] for j in range(chunk_ids.shape[1]): tok = chunk_ids[:, j : j + 1] # [1, 1] tok_emb = self.llm.get_input_embeddings()(tok) out = self.llm( inputs_embeds=tok_emb, past_key_values=self.llm_past_key_values, use_cache=True, output_hidden_states=True, return_dict=True, ) self.llm_past_key_values = out.past_key_values chunk_hidden_list.append(out.hidden_states[-1]) # [1, 1, hidden] chunk_hidden = torch.cat(chunk_hidden_list, dim=1) # [1, chunk_len, hidden] text_finished = end >= total_len # Save token IDs cache (external eval script will use _last_streaming_text to write generated_text) if not hasattr(self, "_streaming_generated_token_ids"): self._streaming_generated_token_ids = [] self._streaming_generated_token_ids.extend(chunk_ids[0].tolist()) if not generate_audio: yield chunk_ids, text_finished else: llm_embeds = self.tts.emb_text(chunk_ids) hidden_embeds = self.tts.projector_semantic(chunk_hidden) if self.tts.config.normalize_projected_hidden: hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1) tts_embeds = llm_embeds + hidden_embeds tts_generator = tts_streaming_generator.generate_with_buffer( condition=tts_embeds, text_finished=text_finished, ) for audio_token_chunk, is_last_audio_chunk in tts_generator: yield audio_token_chunk, is_last_audio_chunk ptr = end if text_finished: if generate_audio: self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens break # Finish: decode this round of text if hasattr(self, "_streaming_generated_token_ids"): try: self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids) assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text) self._finalize_round( round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids ) except Exception: self._last_streaming_text = None else: self._last_streaming_text = None # Finally send the end signal if generate_audio: yield None, None else: return return # LLM chunk generate outer loop for chunk_idx in range(num_chunks_decode): is_first_generate_chunk = chunk_idx == 0 output = llm_streaming_generator.chunk_generate( inputs_embeds=generation_inputs_embeds, past_key_values=self.llm_past_key_values, is_first_generate_chunk=is_first_generate_chunk, return_hidden_states=True, chunk_size=generate_chunk_size + 1 * is_first_generate_chunk, do_sample=do_sample, temperature=kwargs.get("temperature", 0.7), top_p=kwargs.get("top_p", 0.8), top_k=kwargs.get("top_k", 100), repetition_penalty=kwargs.get("repetition_penalty", 1.02), length_penalty=kwargs.get("length_penalty", 1.0), all_input_ids=generated_ids, ) if output.chunk_token_ids is None: break if is_first_generate_chunk: if generate_audio: spk_emb = torch.empty( (bos_input_embeds.shape[0], 0, bos_input_embeds.shape[2]), dtype=bos_input_embeds.dtype, device=bos_input_embeds.device, ) tts_streaming_generator.spk_emb = spk_emb if output.finished: yield_chunk_token_ids = output.chunk_token_ids else: # the first chunk generated chunk_size + 1 tokens, we only take the first chunk_size tokens, # the last token is not prefilled, and last hidden states is not obtained yield_chunk_token_ids = output.chunk_token_ids[:, :-1] elif output.finished: yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids], dim=1) else: # in the chunk that is not the first chunk, we need to add the token at the end of the previous chunk, # it is not prefilled into the model to get last hidden states # similarly, the last generated token of subsequent chunks is not prefilled, and last hidden states is not obtained, # so it is not passed out yield_chunk_token_ids = torch.cat([generated_ids[:, -1:], output.chunk_token_ids[:, :-1]], dim=1) if not generate_audio: chunk_generated_text = tokenizer.decode(yield_chunk_token_ids[0]) yield yield_chunk_token_ids, output.finished else: # TTS inner loop # dense connection here is hardcoded to use text-hidden merged as condition llm_embeds = self.tts.emb_text(yield_chunk_token_ids) hidden_embeds = output.last_hidden_states hidden_embeds = self.tts.projector_semantic(hidden_embeds) if self.tts.config.normalize_projected_hidden: # default should be opened hidden_embeds = F.normalize(hidden_embeds, p=2, dim=-1) tts_embeds = llm_embeds + hidden_embeds conditions.append(tts_embeds) # Store token IDs instead of decoded text to avoid UTF-8 multi-byte character truncation if not hasattr(self, "_streaming_generated_token_ids"): self._streaming_generated_token_ids = [] self._streaming_generated_token_ids.extend(yield_chunk_token_ids[0].tolist()) # there is buffer generated, each time exactly returns 25 audio tokens, # the last audio chunk returns audio tokens of variable length, length [0, 25] tts_generator = tts_streaming_generator.generate_with_buffer( condition=tts_embeds, text_finished=output.finished ) for audio_token_chunk, is_last_audio_chunk in tts_generator: yield audio_token_chunk, is_last_audio_chunk generated_ids = torch.cat([generated_ids, output.chunk_token_ids], dim=1) generation_inputs_embeds = output.current_inputs_embeds self.llm_past_key_values = output.past_key_values if output.finished: if generate_audio: self.tts_last_turn_tokens = tts_streaming_generator.tts_last_turn_tokens break # IMPORTANT: Flush remaining TTS buffer when LLM generation ends # This handles BOTH cases: # 1. LLM finished with terminator (output.finished=True) - buffer may still have tokens # 2. LLM hit max chunks limit (output.finished=False) - buffer definitely has tokens if generate_audio: if len(tts_streaming_generator._token_buffer) > 0: batch = torch.cat(tts_streaming_generator._token_buffer, dim=1) yield batch, True tts_streaming_generator._token_buffer = [] if generate_audio: if hasattr(self, "_streaming_generated_token_ids"): try: self._last_streaming_text = tokenizer.decode(self._streaming_generated_token_ids) assistant_input_ids = self._encode_text(tokenizer=tokenizer, text=self._last_streaming_text) self._finalize_round( round_id=host_round_id, cache_before=cache_length, assistant_input_ids=assistant_input_ids ) except Exception: self._last_streaming_text = None else: self._last_streaming_text = None yield None, None else: return # iter for generating text chunk and audio chunk audio_chunk_generator_iter = audio_chunk_generator( bos_input=bos_input, tokenizer=self.processor.tokenizer, generate_audio=generate_audio, tts_sampling_params=tts_sampling_params, max_new_tokens=max_new_tokens, do_sample=do_sample, teacher_forcing=teacher_forcing, teacher_forcing_text=teacher_forcing_text, **kwargs, ) if generate_audio: if self.tts.config.audio_tokenizer_type == "s3tokenizer_step_audio": self.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.token2wav_cache["flow_cache_base"]) self.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive( self.token2wav_cache["hift_cache_base"] ) # pre-insert 3-5 prefix 4218 silence tokens, each token corresponds to 0.04s, # adding 5 tokens means introducing 0.2s of silence buffer = [4218] * 3 pre_lookahead = 3 CHUNK_SIZE = 25 chunk_idx = 0 prev_text_len = 0 # track text position for streaming text output for audio_token_chunk, is_last_audio_chunk in audio_chunk_generator_iter: if audio_token_chunk is None: break buffer += audio_token_chunk.reshape(-1).tolist() if len(buffer) >= CHUNK_SIZE + pre_lookahead: waveform_chunk = self.tts.audio_tokenizer.stream( buffer[: CHUNK_SIZE + pre_lookahead], prompt_wav=None, last_chunk=is_last_audio_chunk, return_waveform=True, ) waveform_chunk = torch.from_numpy(waveform_chunk) # get new text chunk corresponding to this waveform # Decode from accumulated token IDs to avoid UTF-8 multi-byte truncation new_text = "" if hasattr(self, "_streaming_generated_token_ids"): current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids) # Filter out trailing replacement characters (incomplete UTF-8 sequences) safe_end = len(current_text) while safe_end > 0 and current_text[safe_end - 1] == "\ufffd": safe_end -= 1 safe_text = current_text[:safe_end] new_text = safe_text[prev_text_len:] prev_text_len = len(safe_text) yield waveform_chunk, new_text buffer = buffer[CHUNK_SIZE:] chunk_idx += 1 # flush rest if len(buffer) > 0: waveform_chunk = self.tts.audio_tokenizer.stream( buffer, prompt_wav=None, last_chunk=True, return_waveform=True, ) waveform_chunk = torch.from_numpy(waveform_chunk) # get remaining new text for the final chunk # Final chunk: decode all remaining text without filtering new_text = "" if hasattr(self, "_streaming_generated_token_ids"): current_text = self.processor.tokenizer.decode(self._streaming_generated_token_ids) new_text = current_text[prev_text_len:] prev_text_len = len(current_text) yield waveform_chunk, new_text # maybe the buffer is empty, and text is not empty, should we flush text without wave? else: raise NotImplementedError(f"not supported audio tokenizer: {self.tts.config.audio_tokenizer_type}") else: # For text-only generation, decode tokens and handle partial multi-byte characters yield from streaming_token_decoder( audio_chunk_generator_iter, self.processor.tokenizer, skip_special_tokens=False, ) def as_duplex(self, device: Optional[str] = None, **kwargs) -> "MiniCPMODuplex": """Convert this MiniCPMO instance to MiniCPMODuplex for full-duplex streaming.""" return MiniCPMODuplex.from_existing_model( model=self, device=device, **kwargs, ) class MiniCPMODuplex: """MiniCPMODuplex model with full-duplex streaming capabilities. This is a wrapper class that provides duplex streaming functionality. Use MiniCPMO.as_duplex() to create from an existing model without reloading. """ # Default duplex parameters _default_duplex_params = { "generate_audio": True, "ls_mode": "explicit", "max_new_speak_tokens_per_chunk": 20, "text_repetition_penalty": 1.05, "temperature": 0.7, "top_k": 100, "top_p": 0.8, "text_repetition_window_size": 512, "listen_prob_scale": 1.0, "force_listen_count": 0, "tts_temperature": 0.8, "tts_repetition_penalty": 1.05, "enable_float16": False, "n_timesteps": 10, "chunk_ms": 1000, "first_chunk_ms": 1035, "cnn_redundancy_ms": 20, "sample_rate": 16000, "sliding_window_mode": "off", "basic_window_high_tokens": 8000, "basic_window_low_tokens": 6000, "context_previous_max_tokens": 500, "context_max_units": 24, } @classmethod def from_existing_model( cls, model: "MiniCPMO", device: Optional[str] = None, **kwargs, ) -> "MiniCPMODuplex": """Create MiniCPMODuplex from an existing MiniCPMO instance.""" # Create instance without calling __init__ instance = cls.__new__(cls) instance.name_or_path = getattr(model.config, "_name_or_path", "") # Get default params helper def get_param(name): if name in kwargs: return kwargs[name] return cls._default_duplex_params.get(name) instance.generate_audio = get_param("generate_audio") instance.ls_mode = get_param("ls_mode") # Determine device if device is not None: instance.device = device else: try: instance.device = str(next(model.parameters()).device) except StopIteration: instance.device = "cuda" # Reuse the existing model - THIS IS THE KEY: no reloading! instance.model = model instance.processor = getattr(model, "processor", None) instance.tokenizer = getattr(instance.processor, "tokenizer", None) if instance.processor else None if instance.tokenizer is None: from transformers import AutoTokenizer instance.tokenizer = AutoTokenizer.from_pretrained(instance.name_or_path, trust_remote_code=True) if instance.processor is None: from .processing_minicpmo import MiniCPMOProcessor instance.processor = MiniCPMOProcessor.from_pretrained(instance.name_or_path, trust_remote_code=True) instance.processor.tokenizer = instance.tokenizer # Ensure model has processor reference (same as __init__) instance.model.processor = instance.processor # Initialize TTS (same as __init__) enable_float16 = get_param("enable_float16") n_timesteps = get_param("n_timesteps") instance.model.init_tts(enable_float16=enable_float16, n_timesteps=n_timesteps) instance.break_event = threading.Event() instance.session_stop_event = threading.Event() # LLM generation config instance.max_new_speak_tokens_per_chunk = get_param("max_new_speak_tokens_per_chunk") instance.text_repetition_penalty = get_param("text_repetition_penalty") instance.temperature = get_param("temperature") instance.top_k = get_param("top_k") instance.top_p = get_param("top_p") instance.text_repetition_window_size = get_param("text_repetition_window_size") instance.listen_prob_scale = get_param("listen_prob_scale") instance.force_listen_count = get_param("force_listen_count") # TTS generation config tts_temp_value = get_param("tts_temperature") instance.tts_temperature = torch.tensor([tts_temp_value], dtype=torch.float, device=instance.device) instance.tts_repetition_penalty = get_param("tts_repetition_penalty") # Stream config instance.CHUNK_MS = get_param("chunk_ms") instance.FIRST_CHUNK_MS = get_param("first_chunk_ms") instance.CNN_REDUNDANCY_MS = get_param("cnn_redundancy_ms") instance.SAMPLE_RATE = get_param("sample_rate") instance.model.CHUNK_MS = instance.CHUNK_MS instance.model.FIRST_CHUNK_MS = instance.FIRST_CHUNK_MS instance.model.CNN_REDUNDANCY_MS = instance.CNN_REDUNDANCY_MS instance.model.SAMPLE_RATE = instance.SAMPLE_RATE # Special tokens instance.unit_token_id = instance.tokenizer.convert_tokens_to_ids("") instance.image_start_token_id = instance.tokenizer.convert_tokens_to_ids("") instance.image_end_token_id = instance.tokenizer.convert_tokens_to_ids("") instance.slice_start_token_id = instance.tokenizer.convert_tokens_to_ids("") instance.slice_end_token_id = instance.tokenizer.convert_tokens_to_ids("") instance.listen_token_id = instance.tokenizer.convert_tokens_to_ids("<|listen|>") instance.speak_token_id = instance.tokenizer.convert_tokens_to_ids("<|speak|>") instance.tts_bos_token_id = instance.tokenizer.convert_tokens_to_ids("<|tts_bos|>") instance.tts_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|tts_eos|>") instance.chunk_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|chunk_eos|>") instance.chunk_tts_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>") instance.turn_eos_token_id = instance.tokenizer.convert_tokens_to_ids("<|turn_eos|>") instance.chunk_terminator_token_ids = [ instance.listen_token_id, instance.chunk_eos_token_id, instance.chunk_tts_eos_token_id, ] instance.turn_terminator_token_ids = [instance.turn_eos_token_id] instance.chunk_speak_token_ids = [instance.speak_token_id] instance.tts_pad_id = instance.tokenizer.convert_tokens_to_ids("<|tts_pad|>") bad_token_ids = getattr(instance.tokenizer, "bad_token_ids", []) instance.forbidden_token_ids = [instance.tts_pad_id] + list(bad_token_ids) from .utils import StreamDecoder instance.decoder = StreamDecoder( llm=instance.model.llm, tokenizer=instance.tokenizer, forbidden_token_ids=instance.forbidden_token_ids ) # Sliding window config sliding_window_mode = get_param("sliding_window_mode") basic_window_high_tokens = get_param("basic_window_high_tokens") basic_window_low_tokens = get_param("basic_window_low_tokens") context_previous_max_tokens = get_param("context_previous_max_tokens") context_max_units = get_param("context_max_units") instance.decoder.set_window_config( DuplexWindowConfig( sliding_window_mode=sliding_window_mode, basic_window_high_tokens=basic_window_high_tokens, basic_window_low_tokens=basic_window_low_tokens, context_previous_max_tokens=context_previous_max_tokens, context_max_units=context_max_units, ) ) window_enabled = sliding_window_mode != "off" instance.decoder.set_window_enabled(window_enabled) instance.tts_logits_processors = None instance.tts_eos_token = None if instance.generate_audio: instance.tts_logits_processors = gen_logits( num_code=instance.model.tts.config.num_audio_tokens, repetition_penalty=instance.tts_repetition_penalty, ) instance.tts_eos_token = torch.tensor( [instance.model.tts.config.num_audio_tokens - 1], dtype=torch.long, device=instance.device, ) instance._reset_streaming_state() return instance def set_break_event(self): self.break_event.set() def clear_break_event(self): self.break_event.clear() def set_session_stop(self): self.session_stop_event.set() self.break_event.set() def clear_session_stop(self): self.session_stop_event.clear() def is_break_set(self) -> bool: return self.break_event.is_set() def is_session_stop_set(self) -> bool: return self.session_stop_event.is_set() def _init_token2wav_cache(self, prompt_wav_path: str): self.model.tts.audio_tokenizer.cache = None flow_cache, hift_cache = self.model.tts.audio_tokenizer.set_stream_cache(prompt_wav_path) self.flow_cache_base = torch_clone_recursive(flow_cache) self.hift_cache_base = torch_clone_recursive(hift_cache) self.pre_lookahead = int(self.model.tts.audio_tokenizer.flow.pre_lookahead_len) self.token2wav_initialized = True def _reset_token2wav_for_new_turn(self): if self.token2wav_initialized: self.model.tts.audio_tokenizer.stream_cache = torch_clone_recursive(self.flow_cache_base) self.model.tts.audio_tokenizer.hift_cache_dict = torch_clone_recursive(self.hift_cache_base) self.token2wav_buffer = [4218] * 3 # silence token prefix def _reset_streaming_state(self): self.audio_chunk_idx = 0 self.current_turn_ended = True self.speak_count = 0 self.res_ids = [] self.total_ids = [] self.total_hidden = [] # TTS state self.tts_text_start_pos = 0 self.tts_past_key_values = None self.tts_current_turn_start_time = None # token2wav state self.token2wav_initialized = False self.token2wav_buffer = [] self.flow_cache_base = None self.hift_cache_base = None # Audio prefill state self.audio_buffer = np.array([], dtype=np.float32) self.pending_logits: Optional[torch.Tensor] = None self.current_mode: Optional[str] = None # Force listen state self._streaming_generate_count = 0 # Schema tracking: record the complete prefill + generate token sequence # prefill_schema_tokens: each element is a list of prefill tokens for a unit # format: [[unit0_prefill_tokens], [unit1_prefill_tokens], ...] self.prefill_schema_tokens = [] self._current_unit_prefill_tokens = [] def prepare( self, prefix_system_prompt: Optional[str] = None, ref_audio: Optional[np.ndarray] = None, prompt_wav_path: Optional[str] = None, context_previous_marker: str = "\n\nprevious: ", **kwargs, ): prefix_system_prompt = prefix_system_prompt or "Streaming Omni Conversation." prefix_system_prompt = "<|im_start|>system\n" + prefix_system_prompt suffix_system_prompt = "<|im_end|>" if isinstance(ref_audio, np.ndarray): prefix_system_prompt += "\n<|audio_start|>" suffix_system_prompt = "<|audio_end|>" + suffix_system_prompt self.clear_break_event() self.clear_session_stop() self._reset_streaming_state() self.decoder.reset() self.model.init_streaming_processor() if prompt_wav_path is not None and prompt_wav_path and self.generate_audio: self._init_token2wav_cache(prompt_wav_path) self._reset_token2wav_for_new_turn() # Prefill system prompt prefix if prefix_system_prompt: tokens = self.tokenizer.encode(prefix_system_prompt, add_special_tokens=False) for token_id in tokens: self.decoder.feed(self.decoder.embed_token(token_id)) # Prefill reference audio if ref_audio is not None: data = self.processor.process_audio([ref_audio]) embeds_nested = self.model.get_audio_embedding(data, chunk_length=self.model.config.audio_chunk_length) embeds = torch.cat([t for g in embeds_nested for t in g], dim=0) if embeds_nested else None if embeds is not None: self.decoder.feed(embeds) # register system prompt protection length (protect this part from being removed when sliding window is enabled) if prefix_system_prompt or suffix_system_prompt or ref_audio is not None: if self.decoder._window_config.sliding_window_mode == "context": # Context preserve mode: # initial layout: [prefix] [suffix] [units...] # after the first sliding window: [prefix] [context_previous_marker + content] [suffix] [units...] # register prefix length first, then feed suffix self._prefix_system_prompt = prefix_system_prompt self._suffix_system_prompt = suffix_system_prompt self._ref_audio = ref_audio suffix_token_ids = [] if suffix_system_prompt: suffix_token_ids = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False) # register (when cache only has prefix, no suffix, no previous) self.decoder.register_system_prompt_with_context( suffix_token_ids=suffix_token_ids, context_previous_marker=context_previous_marker, # dynamically added after the first sliding window ) # now feed suffix for token_id in suffix_token_ids: self.decoder.feed(self.decoder.embed_token(token_id)) else: # non-context preserve mode: first feed suffix, then register total length if suffix_system_prompt: tokens = self.tokenizer.encode(suffix_system_prompt, add_special_tokens=False) for token_id in tokens: self.decoder.feed(self.decoder.embed_token(token_id)) self.decoder.register_system_prompt() if prefix_system_prompt or suffix_system_prompt: if ref_audio is not None: full_prompt = (prefix_system_prompt or "") + "[audio embedding]" + (suffix_system_prompt or "") else: full_prompt = (prefix_system_prompt or "") + (suffix_system_prompt or "") return full_prompt return "" @torch.no_grad() def streaming_prefill( self, audio_waveform: Optional[np.ndarray] = None, frame_list: Optional[list] = None, text_list: Optional[list] = None, max_slice_nums: Union[int, List[int]] = 1, batch_vision_feed: bool = False, ): """Streaming prefill - called once per second, processing audio/video data Args: audio_waveform: audio waveform data frame_list: image frame list text_list: text max_slice_nums: maximum number of slices for HD image encoding (default 1, no slicing) Can be an int (same for all images) or a list matching frame_list length batch_vision_feed: if True, batch all vision embeddings into a single feed call for better performance. if False (default), feed each embedding individually (original behavior). Process: 0. determine mode based on input: AUDIO / VISION / OMNI 1. feed token 2. get and feed image embed (if frame_list) - return pending logits in VISION MODE 3. get and feed audio embed (if audio_waveform) - return pending logits in AUDIO/OMNI MODE Returns: dict with keys: - success: bool - cost_vision_process: float (image processing time) - cost_vision_embed: float (vision embedding time) - cost_vision_feed: float (vision feed time) - cost_audio_process: float (audio processing time) - cost_audio_embed: float (audio embedding time) - cost_audio_feed: float (audio feed time) - cost_all: float (total time) """ start_time = time.time() cost_vision_process = 0.0 cost_vision_embed = 0.0 cost_vision_feed = 0.0 cost_audio_process = 0.0 cost_audio_embed = 0.0 cost_audio_feed = 0.0 def _make_result(success, reasons=""): reason = reasons if isinstance(reasons, list): reason = "; ".join(reasons) return { "success": success, "reason": reason, "cost_vision_process": cost_vision_process, "cost_vision_embed": cost_vision_embed, "cost_vision_feed": cost_vision_feed, "cost_audio_process": cost_audio_process, "cost_audio_embed": cost_audio_embed, "cost_audio_feed": cost_audio_feed, "cost_all": time.time() - start_time, } if self.is_session_stop_set() or self.is_break_set(): return _make_result(False) has_frames = frame_list is not None and len(frame_list) > 0 has_audio = audio_waveform is not None and len(audio_waveform) > 0 has_text = text_list is not None and len(text_list) > 0 if has_frames and has_audio: mode = "OMNI" elif has_frames: mode = "VISION" elif has_audio: mode = "AUDIO" elif has_text: mode = "TEXT" else: return _make_result(False) self.pending_logits = None # sliding window: record unit start position self.decoder.register_unit_start() # Schema tracking: start new unit, record prefill tokens self._current_unit_prefill_tokens = [] # Step 1: Feed token self.decoder.feed(self.decoder.embed_token(self.unit_token_id)) self._current_unit_prefill_tokens.append(self.unit_token_id) # Step 2: process image if has_frames: t0 = time.time() # normalize max_slice_nums to a list matching frame_list length if isinstance(max_slice_nums, int): max_slice_nums_list = [max_slice_nums] * len(frame_list) else: max_slice_nums_list = list(max_slice_nums) if len(max_slice_nums_list) != len(frame_list): raise ValueError( f"max_slice_nums list length ({len(max_slice_nums_list)}) " f"must match frame_list length ({len(frame_list)})" ) # check if all max_slice_nums are the same (can use batch processing) all_same = len(set(max_slice_nums_list)) == 1 if all_same: # all images use the same max_slice_nums, use batch processing processed_frames = self.processor.process_image(frame_list, max_slice_nums=max_slice_nums_list[0]) if self.device: processed_frames = processed_frames.to(self.device) else: # different max_slice_nums per image, process individually and merge all_pixel_values = [] all_tgt_sizes = [] for frame, max_slices in zip(frame_list, max_slice_nums_list): pf = self.processor.process_image([frame], max_slice_nums=max_slices) if self.device: pf = pf.to(self.device) # pf["pixel_values"][0] is the list of slices for this image all_pixel_values.extend(pf["pixel_values"][0]) # pf["tgt_sizes"][0] is the array of target sizes for this image's slices if hasattr(pf["tgt_sizes"][0], "tolist"): all_tgt_sizes.extend(pf["tgt_sizes"][0].tolist()) else: all_tgt_sizes.extend(list(pf["tgt_sizes"][0])) # reconstruct processed_frames with merged data processed_frames = { "pixel_values": [all_pixel_values], "tgt_sizes": [torch.tensor(all_tgt_sizes) if all_tgt_sizes else []], } cost_vision_process = time.time() - t0 t0 = time.time() # get vision embeddings for all images (each may have multiple slices) # vision_hidden_states is a list, one entry per input image # each entry contains embeddings for [source_image, slice_1, slice_2, ...] vision_hidden_states = self.model.get_vision_embedding(processed_frames) cost_vision_embed = time.time() - t0 if vision_hidden_states is not None and len(vision_hidden_states) > 0: t0 = time.time() # vision_hidden_states[0] contains ALL slices from ALL images (flattened) # shape: [total_slices, 64, D] where total_slices = sum of slices across all images # we need to know how many slices each image has to correctly group them # calculate slice counts for each image using get_sliced_grid (lightweight, no actual slicing) slice_counts = [] # e.g., [5, 9] means img1 has 5 slices (1 source + 4 HD), img2 has 9 slices for frame_idx, frame in enumerate(frame_list): max_slices = max_slice_nums_list[frame_idx] if hasattr(frame, "size"): # get_sliced_grid returns [M, N] grid or None if no slicing needed # total images = 1 (source) + M * N (HD slices) grid = self.processor.image_processor.get_sliced_grid( frame.size, max_slices, nerver_split=False ) if grid is not None: slice_counts.append(1 + grid[0] * grid[1]) # 1 source + M*N slices else: slice_counts.append(1) # no slicing, only source image else: slice_counts.append(1) # default: single image, no slicing # get the flattened embeddings tensor # vision_hidden_states is a list with one element (the batch) # vision_hidden_states[0] shape: [total_slices, 64, D] all_embeds = vision_hidden_states[0] # collect all feed operations first, then execute # this allows us to identify the last token for VISION mode logits feed_operations = [] # List of (embed, is_last_for_vision_mode, token_id_or_none) embed_idx = 0 # current index in all_embeds for img_idx, num_slices in enumerate(slice_counts): if num_slices == 0: continue # the first embedding is always the source image (downsampled overview) # Feed token feed_operations.append( (self.decoder.embed_token(self.image_start_token_id), False, self.image_start_token_id) ) # Feed source image embedding (shape: [64, D]) - use None to indicate embedding feed_operations.append((all_embeds[embed_idx], False, None)) # Feed token feed_operations.append( (self.decoder.embed_token(self.image_end_token_id), False, self.image_end_token_id) ) embed_idx += 1 # remaining embeddings are HD slices (if num_slices > 1) if num_slices > 1: for slice_i in range(1, num_slices): # Feed token feed_operations.append( (self.decoder.embed_token(self.slice_start_token_id), False, self.slice_start_token_id) ) # Feed slice embedding (shape: [64, D]) feed_operations.append((all_embeds[embed_idx], False, None)) # Feed token feed_operations.append( (self.decoder.embed_token(self.slice_end_token_id), False, self.slice_end_token_id) ) embed_idx += 1 # mark the last operation for VISION mode logits if feed_operations: feed_operations[-1] = (feed_operations[-1][0], True, feed_operations[-1][2]) # execute feed operations if batch_vision_feed and feed_operations: # batch mode: concatenate all embeddings and feed at once # this reduces LLM forward passes from N to 1 # # NOTE: batch mode may have slight numerical differences compared to for-loop mode # due to floating-point precision in attention computation. This is expected behavior # for causal attention with incremental vs batch computation. all_embeds_list = [] for embed, is_last, token_id in feed_operations: # ensure all embeddings have shape [L, H] if embed.dim() == 1: embed = embed.unsqueeze(0) all_embeds_list.append(embed) # concatenate all embeddings # torch.cat requires consistent dtype; embeddings should already be same dtype all_embeds_to_feed = torch.cat(all_embeds_list, dim=0) # [total_L, H] if mode == "VISION": # vision mode needs logits from the last token self.pending_logits, _ = self.decoder.feed(all_embeds_to_feed, return_logits=True) else: # omni mode: just feed, wait for audio to get logits self.decoder.feed(all_embeds_to_feed) # schema tracking: record all token IDs and embedding markers for embed, is_last, token_id in feed_operations: if token_id is not None: self._current_unit_prefill_tokens.append(token_id) else: embed_dim = embed.shape[0] if len(embed.shape) > 1 else 1 self._current_unit_prefill_tokens.append(("img", embed_dim)) else: for embed, is_last, token_id in feed_operations: if mode == "VISION" and is_last: # get logits from the last token self.pending_logits, _ = self.decoder.feed(embed, return_logits=True) else: self.decoder.feed(embed) # schema tracking: record token ID or embedding marker if token_id is not None: self._current_unit_prefill_tokens.append(token_id) else: # use tuple to mark image embedding: ("img", dim) embed_dim = embed.shape[0] if len(embed.shape) > 1 else 1 self._current_unit_prefill_tokens.append(("img", embed_dim)) # for omni mode, no pending logits needed here (wait for audio) cost_vision_feed = time.time() - t0 # Step 3: process audio (if any) if has_audio: # accumulate audio to buffer self.audio_buffer = np.concatenate([self.audio_buffer, audio_waveform]) # calculate required audio length if self.audio_chunk_idx == 0: required_samples = int(self.FIRST_CHUNK_MS * self.SAMPLE_RATE / 1000) if len(self.audio_buffer) < required_samples: padding_samples = required_samples - len(self.audio_buffer) padding = np.zeros(padding_samples, dtype=np.float32) self.audio_buffer = np.concatenate([padding, self.audio_buffer]) else: required_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000) need_samples = self.processor.get_streaming_chunk_size() if len(self.audio_buffer) < need_samples: return _make_result( False, f"audio not enough: need {need_samples} samples, only {len(self.audio_buffer)}" ) audio_chunk = self.audio_buffer[:need_samples] t0 = time.time() batch_feature = self.processor.process_audio_streaming( audio_chunk, reset=False, return_batch_feature=True, ) if batch_feature is None or batch_feature.audio_features.shape[-1] == 0: return _make_result(False, "streaming audio processing returned empty") # metadata batch_feature.chunk_idx = self.audio_chunk_idx batch_feature.use_extra_context = True batch_feature.prefix_extra_frames = 0 if self.audio_chunk_idx == 0 else 2 batch_feature.suffix_extra_frames = 2 batch_feature = batch_feature.to(self.device) cost_audio_process = time.time() - t0 t0 = time.time() embeds_nested = self.model.get_audio_embedding_streaming( batch_feature, use_extra_context=batch_feature.use_extra_context, prefix_extra_frames=batch_feature.prefix_extra_frames, suffix_extra_frames=batch_feature.suffix_extra_frames, ) audio_embeds = torch.cat([t for g in embeds_nested for t in g], dim=0) cost_audio_embed = time.time() - t0 t0 = time.time() self.pending_logits, _ = self.decoder.feed(audio_embeds, return_logits=True) cost_audio_feed = time.time() - t0 # schema tracking: use tuple to mark audio embedding: ("audio", dim) embed_dim = audio_embeds.shape[0] if len(audio_embeds.shape) > 1 else 1 self._current_unit_prefill_tokens.append(("audio", embed_dim)) if self.audio_chunk_idx == 0: cfg = self.processor._streaming_mel_processor.get_config() consumed_ms = int(cfg.get("effective_first_chunk_ms", self.FIRST_CHUNK_MS)) consumed_samples = int(consumed_ms * self.SAMPLE_RATE / 1000) else: consumed_samples = int(self.CHUNK_MS * self.SAMPLE_RATE / 1000) self.audio_buffer = self.audio_buffer[consumed_samples:] self.audio_chunk_idx += 1 # Step 4: process text if has_text: # concatenate all text items text_content = "".join(text_list) if isinstance(text_list, list) else str(text_list) # tokenize text text_token_ids = self.tokenizer.encode(text_content, add_special_tokens=False) if len(text_token_ids) > 0: # get token embeddings text_token_ids_tensor = torch.tensor(text_token_ids, dtype=torch.long, device=self.device) text_embeds = self.decoder.embed_token(text_token_ids_tensor) # feed to decoder if mode == "TEXT": # text-only mode: get logits from the last token self.pending_logits, _ = self.decoder.feed(text_embeds, return_logits=True) else: # mixed mode: just feed, let other modality get logits self.decoder.feed(text_embeds) # schema tracking: record text token IDs for token_id in text_token_ids: self._current_unit_prefill_tokens.append(token_id) self.current_mode = mode if mode == "VISION": self.audio_chunk_idx += 1 # schema tracking: save current unit's prefill tokens self.prefill_schema_tokens.append(self._current_unit_prefill_tokens) return _make_result(True) @torch.no_grad() def streaming_generate( self, prompt_wav_path=None, max_new_speak_tokens_per_chunk=20, decode_mode: str = "sampling", temperature=0.7, top_k=100, top_p=0.8, listen_prob_scale=1.0, listen_top_k=None, text_repetition_penalty=1.05, text_repetition_window_size=512, ): start_time = time.time() if self.is_session_stop_set() or self.is_break_set(): return { "is_listen": True, "text": "", "audio_waveform": self._generate_silence_waveform(), "end_of_turn": True, "current_time": self.audio_chunk_idx, "cost_llm": 0.0, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": 0, "n_tts_tokens": 0, } # check if there are pending logits to process if not hasattr(self, "pending_logits") or self.pending_logits is None: return { "is_listen": True, "text": "", "audio_waveform": self._generate_silence_waveform(), "end_of_turn": False, "current_time": self.audio_chunk_idx, "cost_llm": 0.0, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": 0, "n_tts_tokens": 0, } # use pending logits generated in streaming_prefill logits = self.pending_logits self.pending_logits = None # Force listen: check if we should force listen for first N calls force_listen = self._streaming_generate_count < self.force_listen_count self._streaming_generate_count += 1 total_hidden_in_unit = [] total_ids_in_unit = [] current_time = self.audio_chunk_idx is_listen = False end_of_turn = False llm_start_time = time.time() for j in range(max_new_speak_tokens_per_chunk): if j == max_new_speak_tokens_per_chunk - 1: if self.ls_mode == "explicit": self.decoder.feed(self.decoder.embed_token(self.chunk_eos_token_id)) self.total_ids.append(self.chunk_eos_token_id) break if force_listen: last_id = torch.tensor([self.listen_token_id], dtype=torch.long, device=self.device) else: last_id = self.decoder.decode( logits=logits, mode=decode_mode, temperature=temperature, top_k=top_k, top_p=top_p, listen_top_k=listen_top_k, listen_prob_scale=listen_prob_scale, text_repetition_penalty=text_repetition_penalty, text_repetition_window_size=text_repetition_window_size, ) # if current turn not ended, not allowed to listen (only check when not force_listen) if last_id.item() == self.listen_token_id and (not self.current_turn_ended): last_id = torch.tensor([self.tts_bos_token_id], dtype=torch.long, device=self.device) self.total_ids.append(last_id.item()) is_listen = last_id.item() == self.listen_token_id # termination condition detection if last_id.item() in self.chunk_terminator_token_ids: if self.ls_mode == "explicit": logits, _ = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True) break else: # normal speak self.current_turn_ended = False if last_id.item() in self.chunk_speak_token_ids: pass else: self.res_ids.append(last_id.item()) self.speak_count += 1 logits, hidden = self.decoder.feed(self.decoder.embed_token(last_id.item()), return_logits=True) assert len(hidden.shape) == 3 assert hidden.shape[0] == 1 assert hidden.shape[1] == 1 end_of_turn = last_id.item() in self.turn_terminator_token_ids if end_of_turn: self.current_turn_ended = True if j != 0: total_hidden_in_unit.append([last_id.item(), hidden, end_of_turn]) total_ids_in_unit.append(last_id.item()) # Prefill token unit_end_id = self.tokenizer.convert_tokens_to_ids("") self.decoder.feed(self.decoder.embed_token(unit_end_id)) self.total_ids.append(unit_end_id) # calculate generated text (for sliding window context preserve, filter out special tokens) generated_text = self.tokenizer.decode(total_ids_in_unit, skip_special_tokens=True) if total_ids_in_unit else "" # sliding window: register unit end, and check if sliding window is needed input_type = self.current_mode.lower() if self.current_mode else "audio" self.decoder.register_unit_end( input_type=input_type, generated_tokens=total_ids_in_unit, is_listen=is_listen, generated_text=generated_text, ) # select sliding window method based on sliding window mode if self.decoder._window_config.sliding_window_mode == "context": self.decoder.enforce_window_with_context() elif self.decoder._window_config.sliding_window_mode == "basic": self.decoder.enforce_window() llm_end_time = time.time() if is_listen: self.total_hidden.append([]) return { "is_listen": True, "text": "", "audio_waveform": self._generate_silence_waveform(), "end_of_turn": False, "current_time": current_time, "cost_llm": llm_end_time - llm_start_time, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": len(total_ids_in_unit), "n_tts_tokens": 0, } self.total_hidden.append(total_hidden_in_unit) text = generated_text # reuse already calculated text if not self.generate_audio: return { "is_listen": False, "text": text, "audio_waveform": None, "end_of_turn": end_of_turn, "current_time": current_time, "cost_llm": llm_end_time - llm_start_time, "cost_tts_prep": 0.0, "cost_tts": 0.0, "cost_token2wav": 0.0, "cost_all": time.time() - start_time, "n_tokens": len(total_ids_in_unit), "n_tts_tokens": 0, } # TTS generate tts_start_time = time.time() tts_prep_start_time = time.time() tts_condition = self._convert_results_to_tts_input(total_hidden_in_unit) tts_prep_end_time = time.time() max_token_per_chunk = 25 + 1 min_token_per_chunk = 25 + 1 if end_of_turn: min_token_per_chunk = 0 force_flush = False if self.tts_text_start_pos == 0: # this is the start of the turn min_token_per_chunk = 0 # allow decoding <1s audio force_flush = True if self.tts_current_turn_start_time is None: self.tts_current_turn_start_time = current_time new_tokens, old_kv = self.model.tts.generate_chunk( inputs_embeds=tts_condition, temperature=self.tts_temperature, repetition_penalty=self.tts_repetition_penalty, eos_token=self.tts_eos_token, force_no_stop=False, max_new_token=max_token_per_chunk, min_new_tokens=min_token_per_chunk, past_key_values=self.tts_past_key_values, logits_processors=self.tts_logits_processors, text_start_pos=self.tts_text_start_pos, ) tts_end_time = time.time() # update TTS state (note: token2wav reset must be after audio generation, otherwise tokens in buffer will be lost) if end_of_turn: self.tts_text_start_pos = 0 self.tts_past_key_values = None self.tts_current_turn_start_time = None else: self.tts_past_key_values = old_kv self.tts_text_start_pos += tts_condition.shape[1] + new_tokens.shape[1] # token2wav generation (must be before reset, otherwise tokens in the last but second chunk will be lost) token2wav_start_time = time.time() audio_waveform = self._generate_waveform_from_tokens( new_tokens, prompt_wav_path, end_of_turn, force_flush=force_flush ) token2wav_end_time = time.time() # reset token2wav state after audio generation, ensure all tokens in buffer are processed if end_of_turn: self._reset_token2wav_for_new_turn() end_time = time.time() return { "is_listen": False, "text": text, "audio_waveform": audio_waveform, "end_of_turn": end_of_turn, "current_time": current_time, "cost_llm": llm_end_time - llm_start_time, "cost_tts_prep": tts_prep_end_time - tts_prep_start_time, "cost_tts": tts_end_time - tts_start_time, "cost_token2wav": token2wav_end_time - token2wav_start_time, "cost_all": end_time - start_time, "n_tokens": len(total_ids_in_unit), "n_tts_tokens": new_tokens.numel(), } def get_session_schema(self, include_embeddings: bool = True) -> str: """get complete schema for current session (includes prefill and generate stages) Args: include_embeddings: whether to include embedding placeholders (e.g. [img_embed_64], [audio_embed_50]) Returns: complete schema string, each unit format: [img_embed_64][audio_embed_50]<|listen|or|speak|>generated_content """ if not hasattr(self, "prefill_schema_tokens") or not hasattr(self, "total_ids"): return "" # get token id for splitting generate tokens unit_end_token_id = self.tokenizer.convert_tokens_to_ids("") # split generate tokens into each unit generate_units = [] current_unit = [] for tid in self.total_ids: current_unit.append(tid) if tid == unit_end_token_id: generate_units.append(current_unit) current_unit = [] # build complete schema full_schema_parts = [] num_units = max(len(self.prefill_schema_tokens), len(generate_units)) for unit_idx in range(num_units): unit_schema = "" # prefill part if unit_idx < len(self.prefill_schema_tokens): prefill_tokens = self.prefill_schema_tokens[unit_idx] for item in prefill_tokens: if isinstance(item, tuple): # tuple represents embedding: ("img", dim) or ("audio", dim) embed_type, embed_dim = item if include_embeddings: unit_schema += f"[{embed_type}_embed_{embed_dim}]" else: # normal token ID unit_schema += self.tokenizer.decode([item], skip_special_tokens=False) # generate part if unit_idx < len(generate_units): unit_schema += self.tokenizer.decode(generate_units[unit_idx], skip_special_tokens=False) full_schema_parts.append(unit_schema) return "".join(full_schema_parts) def get_unit_schemas(self, include_embeddings: bool = True) -> list: """get list of schema for each unit Returns: list of schema strings for each unit """ if not hasattr(self, "prefill_schema_tokens") or not hasattr(self, "total_ids"): return [] unit_end_token_id = self.tokenizer.convert_tokens_to_ids("") # split generate tokens into each unit generate_units = [] current_unit = [] for tid in self.total_ids: current_unit.append(tid) if tid == unit_end_token_id: generate_units.append(current_unit) current_unit = [] # build schema for each unit unit_schemas = [] num_units = max(len(self.prefill_schema_tokens), len(generate_units)) for unit_idx in range(num_units): unit_schema = "" # prefill part if unit_idx < len(self.prefill_schema_tokens): prefill_tokens = self.prefill_schema_tokens[unit_idx] for item in prefill_tokens: if isinstance(item, tuple): # tuple represents embedding: ("img", dim) or ("audio", dim) embed_type, embed_dim = item if include_embeddings: unit_schema += f"[{embed_type}_embed_{embed_dim}]" else: # normal token ID unit_schema += self.tokenizer.decode([item], skip_special_tokens=False) # generate part if unit_idx < len(generate_units): unit_schema += self.tokenizer.decode(generate_units[unit_idx], skip_special_tokens=False) unit_schemas.append(unit_schema) return unit_schemas def _convert_results_to_tts_input(self, results): """convert LLM hidden states to TTS input""" if len(results) == 0: audio_bos = self.model.tts.emb_text( torch.tensor( [self.model.tts.audio_bos_token_id], device=self.model.tts.emb_text.weight.device, dtype=torch.long, ) ) return audio_bos.unsqueeze(0) llm_tokens = [] llm_hidden = [] for hidden in results: llm_tokens.append(hidden[0]) llm_hidden.append(hidden[1].squeeze(0)) llm_tokens_tensor = torch.Tensor(llm_tokens).to(self.device, dtype=torch.long) llm_embeds = self.model.tts.emb_text(llm_tokens_tensor) llm_hidden_tensor = torch.cat(llm_hidden, dim=0) llm_hidden_tensor = self.model.tts.projector_semantic(llm_hidden_tensor) llm_hidden_tensor = torch.nn.functional.normalize(llm_hidden_tensor, p=2, dim=-1) tts_embeds = llm_embeds + llm_hidden_tensor audio_bos = self.model.tts.emb_text( torch.tensor( [self.model.tts.audio_bos_token_id], device=self.model.tts.emb_text.weight.device, dtype=torch.long, ) ) tts_embeds = torch.cat([tts_embeds, audio_bos], dim=0) return tts_embeds.unsqueeze(0) def _generate_waveform_from_tokens( self, new_tokens: torch.Tensor, prompt_wav_path: Optional[str], is_last_chunk: bool = False, force_flush: bool = False, ) -> Optional[np.ndarray]: if not self.token2wav_initialized: logger.warning("token2wav_initialized is uninitialized") return None CHUNK_SIZE = 25 token_ids = torch.reshape(new_tokens, (-1,)).tolist() self.token2wav_buffer += token_ids has_chunk_eos = any(tid in self.chunk_terminator_token_ids for tid in token_ids) pcm_bytes_list = [] # process enough tokens # if there is chunk_eos, try to flush more content if has_chunk_eos or force_flush: # when there is chunk_eos, try to flush more content while len(self.token2wav_buffer) >= self.pre_lookahead + 5: # at least keep some lookahead chunk_to_process = min(CHUNK_SIZE + self.pre_lookahead, len(self.token2wav_buffer)) pcm_bytes = self.model.tts.audio_tokenizer.stream( self.token2wav_buffer[:chunk_to_process], prompt_wav=prompt_wav_path, ) pcm_bytes_list.append(pcm_bytes) self.token2wav_buffer = self.token2wav_buffer[min(CHUNK_SIZE, chunk_to_process - self.pre_lookahead) :] else: while len(self.token2wav_buffer) >= CHUNK_SIZE + self.pre_lookahead: pcm_bytes = self.model.tts.audio_tokenizer.stream( self.token2wav_buffer[: CHUNK_SIZE + self.pre_lookahead], prompt_wav=prompt_wav_path, ) pcm_bytes_list.append(pcm_bytes) self.token2wav_buffer = self.token2wav_buffer[CHUNK_SIZE:] # if is the last chunk, flush remaining tokens if is_last_chunk and len(self.token2wav_buffer) > 0: pcm_bytes = self.model.tts.audio_tokenizer.stream( self.token2wav_buffer, prompt_wav=prompt_wav_path, last_chunk=True, ) pcm_bytes_list.append(pcm_bytes) self.token2wav_buffer = [] if not pcm_bytes_list: return None # merge PCM and convert to numpy array (24kHz, int16 -> float32) all_pcm = b"".join(pcm_bytes_list) if len(all_pcm) == 0: return None pcm_np = np.frombuffer(all_pcm, dtype=" np.ndarray: """generate silence waveform (24kHz)""" sample_rate = 24000 num_samples = int(duration_sec * sample_rate) return np.zeros(num_samples, dtype=np.float32) def get_generated_text(self) -> str: return self.tokenizer.decode(self.res_ids) def get_current_time(self) -> int: return self.audio_chunk_idx def as_simplex(self, reset_session: bool = True, reset_token2wav_cache: bool = False) -> "MiniCPMO": """Convert this MiniCPMODuplex instance back to MiniCPMO for simplex mode. Args: reset_session: If True, reset streaming session state (KV cache, etc.). Recommended when switching from duplex to simplex mode. Returns the underlying MiniCPMO model instance without reloading. """ if reset_session: self.model.reset_session(reset_token2wav_cache=reset_token2wav_cache) return self.model def get_2d_sincos_pos_embed(embed_dim, image_size): """ image_size: image_size or (image_height, image_width) return: pos_embed: [image_height, image_width, embed_dim] """ if isinstance(image_size, int): grid_h_size, grid_w_size = image_size, image_size else: grid_h_size, grid_w_size = image_size[0], image_size[1] grid_h = np.arange(grid_h_size, dtype=np.float32) grid_w = np.arange(grid_w_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0]) # (H, W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1]) # (H, W, D/2) emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) return emb def get_1d_sincos_pos_embed_from_grid_new(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (H, W) out: (H, W, D) """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product emb_sin = np.sin(out) # (H, W, D/2) emb_cos = np.cos(out) # (H, W, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) return emb class Resampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by given learnable queries and 2d sincos pos_emb Outputs: A tensor with the shape of (batch_size, num_queries, embed_dim) """ def __init__( self, num_queries, embed_dim, num_heads, kv_dim=None, norm_layer=partial(nn.LayerNorm, eps=1e-6), adaptive=False, max_size=(70, 70), ): super().__init__() self.num_queries = num_queries self.embed_dim = embed_dim self.num_heads = num_heads self.adaptive = adaptive self.max_size = max_size self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) if kv_dim is not None and kv_dim != embed_dim: self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) else: self.kv_proj = nn.Identity() self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) self.ln_post = norm_layer(embed_dim) self.proj = nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) self._set_2d_pos_cache(self.max_size) def _set_2d_pos_cache(self, max_size, device="cpu"): if is_deepspeed_zero3_enabled(): device = "cuda" pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(self.embed_dim, max_size)).float().to(device) self.register_buffer("pos_embed", pos_embed, persistent=False) def _adjust_pos_cache(self, tgt_sizes, device): max_h = torch.max(tgt_sizes[:, 0]) max_w = torch.max(tgt_sizes[:, 1]) if max_h > self.max_size[0] or max_w > self.max_size[1]: self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])] self._set_2d_pos_cache(self.max_size, device) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward(self, x, tgt_sizes=None): assert x.shape[0] == tgt_sizes.shape[0] bs = x.shape[0] device = x.device dtype = x.dtype patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] self._adjust_pos_cache(tgt_sizes, device=device) max_patch_len = torch.max(patch_len) key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool, device=device) pos_embed = [] for i in range(bs): tgt_h, tgt_w = tgt_sizes[i] pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype)) # patches * D key_padding_mask[i, patch_len[i] :] = True pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute( 1, 0, 2 ) # BLD => L * B * D x = self.kv_proj(x) # B * L * D x = self.ln_kv(x).permute(1, 0, 2) # L * B * D q = self.ln_q(self.query) # Q * D out = self.attn( self._repeat(q, bs), # Q * B * D x + pos_embed, # L * B * D + L * B * D x, key_padding_mask=key_padding_mask, )[0] # out: Q * B * D x = out.permute(1, 0, 2) # B * Q * D x = self.ln_post(x) x = x @ self.proj return x def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) class MiniCPMWhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.d_model try: # compatible old transformers from transformers.models.whisper.modeling_whisper import WHISPER_ATTENTION_CLASSES self.self_attn = WHISPER_ATTENTION_CLASSES[config._attn_implementation]( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, layer_idx=layer_idx, ) except: from transformers.models.whisper.modeling_whisper import WhisperAttention self.self_attn = WhisperAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, layer_head_mask: torch.Tensor, output_attentions: bool = False, past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = False, ) -> torch.Tensor: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights, past_key_values = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, past_key_value=past_key_values, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() ): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) if use_cache: outputs += (past_key_values,) return outputs # Copied from from transformers.models.whisper.modeling_whisper.WhisperEncoder and add use_cache for streaming inference class MiniCPMWhisperEncoder(WhisperEncoder): def __init__(self, config: WhisperConfig): super().__init__(config) self.layers = nn.ModuleList( [MiniCPMWhisperEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)] ) def forward( self, input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = None, use_extra_context: Optional[bool] = False, prefix_extra_frames: Optional[int] = 1, suffix_extra_frames: Optional[int] = 1, cnn_min_length: Optional[int] = None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Ignore copy input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) # Optional: pad short input to minimum length for CNN computation consistency original_length = input_features.shape[2] padded_for_cnn = False if cnn_min_length is not None and original_length < cnn_min_length: padded_features = torch.zeros( input_features.shape[0], input_features.shape[1], cnn_min_length, dtype=input_features.dtype, device=input_features.device, ) padded_features[:, :, :original_length] = input_features input_features = padded_features padded_for_cnn = True conv1_output = self.conv1(input_features) inputs_embeds = nn.functional.gelu(conv1_output) conv2_output = self.conv2(inputs_embeds) inputs_embeds = nn.functional.gelu(conv2_output) # If padding was done before, now need to remove the effect of padding if padded_for_cnn: # Conv1: stride=1, output length=input length # Conv2: stride=2, output length=(input length+1)//2 actual_cnn_output_length = (original_length + 1) // 2 inputs_embeds = inputs_embeds[:, :, :actual_cnn_output_length] # If extra context is used, CNN operations need to remove redundant frames # conv2 stride=2, so the redundant frames in the input will be halved (upward rounding) if use_extra_context: # Input has prefix_extra_frames prefix frames and suffix_extra_frames suffix frames # conv2 stride=2, output length = ceil(input length / 2) # For 2 redundant frames, the output is 1 frame (ceil(2/2) = 1) prefix_to_remove = (prefix_extra_frames + 1) // 2 if prefix_extra_frames > 0 else 0 suffix_to_remove = (suffix_extra_frames + 1) // 2 if suffix_extra_frames > 0 else 0 # Remove redundant frames before and after (batch, channels, time) if prefix_to_remove > 0: inputs_embeds = inputs_embeds[:, :, prefix_to_remove:] if 0 < suffix_to_remove < inputs_embeds.shape[2]: inputs_embeds = inputs_embeds[:, :, :-suffix_to_remove] inputs_embeds = inputs_embeds.permute(0, 2, 1) embed_pos = self.embed_positions.weight past_key_values_length = 0 if use_cache: if past_key_values is None: past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) elif isinstance(past_key_values, list): past_key_values = EncoderDecoderCache(DynamicCache.from_legacy_cache(past_key_values), DynamicCache()) elif isinstance(past_key_values, DynamicCache): past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) else: pass past_key_values_length = past_key_values.self_attention_cache.get_usable_length(inputs_embeds.shape[1]) if inputs_embeds.shape[1] + past_key_values_length > embed_pos.shape[0]: logger.warning("seems the audio is longer than 30s. repeating the last part of the audio") embed_pos_front = embed_pos[past_key_values_length:, :] embed_pos = torch.cat( ( embed_pos_front, torch.repeat_interleave( embed_pos[-1, :].unsqueeze(0), inputs_embeds.shape[1] - embed_pos.shape[0] + past_key_values_length, dim=0, ), ) ) else: embed_pos = embed_pos[past_key_values_length : inputs_embeds.shape[1] + past_key_values_length, :] else: embed_pos = embed_pos[: inputs_embeds.shape[1], :] hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) to_drop = False if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: # skip the layer to_drop = True # Ignore copy if to_drop: layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), output_attentions, past_key_values, use_cache, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), output_attentions=output_attentions, past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_encoder_cache = layer_outputs[2 if output_attentions else 1] else: next_encoder_cache = None if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: result = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return result result = BaseModelOutputWithPast( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, past_key_values=next_encoder_cache, ) return result class MultiModalProjector(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True) self.relu = nn.ReLU() self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True) def forward(self, audio_features): hidden_states = self.relu(self.linear1(audio_features)) hidden_states = self.linear2(hidden_states) return hidden_states class MiniCPMMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.in_dim = config.llm_hidden_size self.out_dim = config.hidden_size self.intermediate_size = config.llm_intermediate_size self.gate_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True) self.up_proj = nn.Linear(self.in_dim, self.intermediate_size, bias=True) self.down_proj = nn.Linear(self.intermediate_size, self.out_dim, bias=True) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj @dataclass class MiniCPMTTSGenerationOutput(ModelOutput): """ Output class for MiniCPMTTS generation. Args: new_ids (torch.LongTensor): Newly generated audio code sequence, shape (batch_size, sequence_length, num_vq). audio_input_ids (torch.LongTensor): Updated input IDs including condition and generated audio codes, shape (batch_size, full_sequence_length, num_vq). past_key_values (Tuple[Tuple[torch.FloatTensor]]): Tuple containing pre-computed keys and values used for attention mechanism. Each element has shape (batch_size, num_heads, sequence_length, embed_size_per_head). finished (bool): Boolean indicating whether generation is complete. """ new_ids: torch.LongTensor = None audio_input_ids: torch.LongTensor = None past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None past_input_ids: Optional[torch.LongTensor] = None finished: bool = None def make_streaming_chunk_mask_inference( tts_text_scope: List[int], tts_text_mask: torch.Tensor, streaming_audio_chunk_size: int = 50, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cuda"), max_sequence_length: int = 4096, ): """ Example: Input sequence: [t1, t2, t3, t4, t5, [Ptts], a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, ...] Output 4D causal mask: ------- text positions ------- [0] <- here is [Stts] [0, 0] <- here is [spk_emb] * N [0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0, 0] ------- audio positions -------- [0, 0, -inf, -inf, -inf, 0] <- here is [Ptts], [Ptts]'s last hidden state should predict the first audio token v- here is [Ptts] [0, 0, -inf, -inf, -inf, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0] [0, 0, -inf, -inf, -inf, 0, 0, 0, 0, 0, 0] # end of first 1s audio chunk [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0 , -inf, -inf, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] """ # Create a complete attention mask for input embeds [batch_size, seq_len], without considering audio mask as audio is always at the end assert tts_text_mask.dtype == torch.int8 padding_mask = torch.ones(max_sequence_length, dtype=torch.int8, device=device) padding_mask[tts_text_scope[0] : tts_text_scope[1]] = tts_text_mask # Initialize a standard upper triangular causal mask min_dtype = torch.finfo(dtype).min causal_mask = torch.full( (max_sequence_length, max_sequence_length), fill_value=min_dtype, dtype=dtype, device=device, ) if max_sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) else: raise ValueError("max_sequence_length of tts could not be 1.") # For each data sample audio_token_start = tts_text_scope[1] audio_duration = max_sequence_length - tts_text_scope[1] # Record which text chunk the current audio chunk can see up to text_pivot = 0 num_valid_text_tokens = torch.sum(tts_text_mask).item() - 1 # [Ptts] excluded # How many audio chunks are in total, the num of buckets should be smaller as possible num_text_tokens_per_audio_chunk = 10 # For each chunk of audio for chunk_idx in range(math.ceil(audio_duration / streaming_audio_chunk_size)): audio_chunk_start = audio_token_start + chunk_idx * streaming_audio_chunk_size audio_chunk_end = audio_token_start + (chunk_idx + 1) * streaming_audio_chunk_size # New text seen by this new audio chunk new_text_this_chunk = num_text_tokens_per_audio_chunk # The right bound of visible text tokens text_pivot = min(new_text_this_chunk + text_pivot, num_valid_text_tokens) # Mask all text chunks after the visible ones # -> [text_pivot, len(tts_text_scope)-1] excluding [Ptts] causal_mask[ audio_chunk_start - 1 : audio_chunk_end - 1, # tts_text_scope[0] + text_pivot: tts_text_scope[1], tts_text_scope[0] + text_pivot : tts_text_scope[1] - 1, ] = min_dtype # Mask the padding parts in tts_text_masks (no position will attend to it) causal_mask[:, padding_mask == 0] = min_dtype # Add extra dimensions, [batch_size, seq_len, seq_len] -> [batch_size, 1, seq_len, seq_len] causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) return causal_mask class MiniCPMTTS(PreTrainedModel): config_class = MiniCPMTTSConfig def __init__(self, config: MiniCPMTTSConfig, audio_tokenizer: None): super().__init__(config) self.use_llm_hidden_state = config.use_llm_hidden_state self.use_text = config.use_text self.streaming = config.streaming self.streaming_text_chunk_min = config.streaming_text_chunk_min self.streaming_text_chunk_max = config.streaming_text_chunk_max self.streaming_audio_chunk_size = config.streaming_audio_chunk_size self.streaming_text_reserved_len = config.streaming_text_reserved_len # streaming tts self.streaming_text_chunk_size = config.streaming_text_chunk_max self.audio_bos_token_id = config.audio_bos_token_id self.num_mel_bins = config.num_mel_bins self.num_vq = config.num_vq self.num_audio_tokens = config.num_audio_tokens self.top_p = config.top_p self.top_k = config.top_k self.repetition_penalty = config.repetition_penalty self.interleaved = config.interleaved self.attention_type = config.attention_type self.recomputed_chunks = config.recomputed_chunks # Two different window size concepts: # 1. chunk_window_size: number of chunks for sliding_recompute mode (default 2) # 2. token_window_size: number of tokens for sliding_window mode (default 300) self.chunk_window_size = config.window_size # chunk-level window for sliding_recompute self.token_window_size = ( config.streaming_sliding_window_audio_window_size ) # token-level window for sliding_window # Legacy aliases (for backward compatibility with existing code) self.window_size = self.chunk_window_size # used in generate_streaming for sliding_recompute self.sliding_window_size = self.token_window_size # used in TTSStreamingGenerator for sliding_window if self.attention_type == "sliding_recompute" and self.chunk_window_size <= self.recomputed_chunks: raise ValueError( f"sliding_recompute requires chunk_window_size > recomputed_chunks, " f"but got chunk_window_size={self.chunk_window_size} and recomputed_chunks={self.recomputed_chunks}" ) if config.backbone_model == "llama": model_config = LlamaConfig( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, num_attention_heads=config.num_attention_heads, num_hidden_layers=config.num_hidden_layers, num_key_value_heads=config.num_key_value_heads, max_position_embeddings=config.max_position_embeddings, attn_implementation=config.attn_implementation, ) self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size) model = LlamaModel(model_config) self.model = model else: raise ValueError(f"Unsupported backbone model: {config.backbone_model}") self.projector_spk = self.create_projector(config) self.projector_semantic = self.create_projector(config) self.audio_tokenizer = audio_tokenizer self.emb_code = nn.ModuleList( [nn.Embedding(config.num_audio_tokens, config.hidden_size) for _ in range(config.num_vq)] ) self.head_code = nn.ModuleList( [ weight_norm( nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False), name="weight", ) for _ in range(config.num_vq) ] ) self.condition_type = config.condition_type return @staticmethod def create_projector(config): if config.projector_type == "mlp": return MultiModalProjector(config.llm_dim, config.hidden_size) elif config.projector_type == "minicpm": return MiniCPMMLP(config) elif config.projector_type == "default": return nn.Linear(config.llm_dim, config.hidden_size, bias=False) else: raise ValueError(f"Unsupported projector type: {config.projector_type}") # non-streaming @torch.inference_mode() def generate( self, inputs_embeds: torch.Tensor, eos_token: Union[int, torch.Tensor], force_no_stop=False, min_new_token=50, max_new_token=2048, show_tqdm=True, streaming=False, text_lengths=None, sampling_params: TTSSamplingParams = TTSSamplingParams(), ): temperature = torch.tensor( [sampling_params.temperature] * self.config.num_vq, dtype=torch.float, device=self.device, ) temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to( inputs_embeds.device ) logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=sampling_params.repetition_penalty, top_p=sampling_params.top_p, top_k=sampling_params.top_k, ) # We only support batch size `1` for now assert inputs_embeds.shape[0] == 1 eos_token = eos_token.to(inputs_embeds.device) finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool() condition_length = inputs_embeds.shape[1] pbar: Optional[tqdm] = None if show_tqdm: pbar = tqdm( total=max_new_token, desc="code", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", ) if streaming: raise NotImplementedError("this kind of streaming is not supported yet") new_tokens = torch.zeros( inputs_embeds.shape[0], max_new_token, self.num_vq, device=inputs_embeds.device, dtype=torch.long, ) past_key_values = None for t in range(max_new_token): audio_bos = False # If this is the first audio token, the case is special if t == 0: audio_bos = True inputs_embeds = inputs_embeds position_ids = torch.tensor( list(range(0, condition_length)), dtype=torch.long, device=self.device, ).unsqueeze(0) if streaming: raise NotImplementedError("this kind of streaming is not supported yet") else: causal_mask_4d = None else: code_emb = [] for q in range(self.num_vq): x = self.emb_code[q](new_tokens[:, t - 1 : t, q]) code_emb.append(x) inputs_embeds = torch.stack(code_emb, 3).sum(3) position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze( 0 ) if streaming: raise NotImplementedError("this kind of streaming is not supported yet") else: causal_mask_4d = None if self.config.backbone_model == "llama": outputs: BaseModelOutputWithPast = self.model( position_ids=position_ids, cache_position=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=causal_mask_4d, use_cache=True, output_attentions=False, # return_dict=True, # Add this to ensure returns dict with past_key_values ) else: raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}") del position_ids del inputs_embeds hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states logits = logits[:, -1].float() logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) logits /= temperature if not audio_bos: input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) for logitsWarpers in logits_warpers: logits = logitsWarpers(logits_token, logits) del logits_token if t < min_new_token: logits[:, eos_token] = -torch.inf if force_no_stop: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) del scores idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or new_tokens[:, t] = idx_next if t == 0 and finish.any(): break del idx_next if finish.all(): break if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() if not finish.all(): logger.warning(f"incomplete result. hit max_new_token: {max_new_token}") genrated_input_ids = new_tokens[:, 0:t, :] return MiniCPMTTSGenerationOutput( new_ids=genrated_input_ids, audio_input_ids=None, # for update purpose past_key_values=None, # for update purpose past_input_ids=None, # for update purpose finished=finish.all(), ) # fake streaming @torch.inference_mode() def generate_mock_legacy_streaming( self, inputs_embeds: torch.Tensor, eos_token: Union[int, torch.Tensor], force_no_stop=False, min_new_token=50, max_new_token=2048, show_tqdm=True, streaming=False, text_lengths=None, sampling_params: TTSSamplingParams = TTSSamplingParams(), valid_text_length=None, ): assert valid_text_length is not None, "valid_text_length should be not None" tts_text_scope = [0, inputs_embeds.shape[1]] tts_text_mask = torch.zeros(inputs_embeds.shape[1], dtype=torch.int8, device=inputs_embeds.device) tts_text_mask[0:valid_text_length] = 1 tts_text_mask[-1] = 1 # [Ptts] streaming_mask_4d_full = make_streaming_chunk_mask_inference( tts_text_scope=tts_text_scope, tts_text_mask=tts_text_mask, dtype=torch.bfloat16, device=self.device, streaming_audio_chunk_size=50, max_sequence_length=4096, ) temperature = torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.device) temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to( inputs_embeds.device ) logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=sampling_params.repetition_penalty, top_p=sampling_params.top_p, top_k=sampling_params.top_k, ) # We only support batch size `1` for now assert inputs_embeds.shape[0] == 1 eos_token = eos_token.to(inputs_embeds.device) finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool() condition_length = inputs_embeds.shape[1] pbar: Optional[tqdm] = None if show_tqdm: pbar = tqdm( total=max_new_token, desc="code", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", ) new_tokens = torch.zeros( inputs_embeds.shape[0], max_new_token, self.num_vq, device=inputs_embeds.device, dtype=torch.long, ) past_key_values = None for t in range(max_new_token): audio_bos = False if t == 0: audio_bos = True inputs_embeds = inputs_embeds position_ids = torch.tensor( list(range(0, condition_length)), dtype=torch.long, device=self.device, ).unsqueeze(0) causal_mask_4d = streaming_mask_4d_full[:, :, :condition_length, :condition_length] else: code_emb = [] for q in range(self.num_vq): x = self.emb_code[q](new_tokens[:, t - 1 : t, q]) code_emb.append(x) inputs_embeds = torch.stack(code_emb, 3).sum(3) position_ids = torch.tensor([condition_length + t - 1], dtype=torch.long, device=self.device).unsqueeze( 0 ) causal_mask_4d = streaming_mask_4d_full[ :, :, condition_length + t : condition_length + t + 1, : condition_length + t, ] # get length of past_key_values past_key_values_length = past_key_values[0][0].shape[2] assert causal_mask_4d.shape[-1] == (past_key_values_length + 1) if self.config.backbone_model == "llama": outputs: BaseModelOutputWithPast = self.model( position_ids=position_ids, cache_position=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=causal_mask_4d, use_cache=True, output_attentions=False, # return_dict=True, # Add this to ensure returns dict with past_key_values ) else: raise ValueError(f"Unsupported backbone model: {self.config.backbone_model}") del position_ids del inputs_embeds hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states logits = logits[:, -1].float() logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) logits /= temperature if not audio_bos: input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) for logitsWarpers in logits_warpers: logits = logitsWarpers(logits_token, logits) del logits_token if t < min_new_token: logits[:, eos_token] = -torch.inf if force_no_stop: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) del scores idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or new_tokens[:, t] = idx_next if t == 0 and finish.any(): break del idx_next if finish.all(): break if pbar is not None: pbar.update(1) if pbar is not None: pbar.close() if not finish.all(): logger.warning(f"incomplete result. hit max_new_token: {max_new_token}") genrated_input_ids = new_tokens[:, 0:t, :] return MiniCPMTTSGenerationOutput( new_ids=genrated_input_ids, audio_input_ids=None, # for update purpose past_key_values=None, # for update purpose past_input_ids=None, # for update purpose finished=finish.all(), ) # non-streaming, interleave @torch.inference_mode() def generate_chunk( self, inputs_embeds: torch.Tensor, temperature: torch.Tensor, repetition_penalty: float, eos_token: Union[int, torch.Tensor], force_no_stop=False, max_new_token=500, min_new_tokens=0, past_key_values=None, logits_processors=None, text_start_pos=None, ): """For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like: |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS| where the last position is the audio BOS token. So, the first iteration in generation directly forward the model with inputs_embeds, and the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token. """ logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty ) # We only support batch size `1` for now assert inputs_embeds.shape[0] == 1 eos_token = eos_token.to(inputs_embeds.device) finish = torch.zeros(inputs_embeds.shape[0], device=inputs_embeds.device).bool() temperature = (temperature.unsqueeze(0).expand(inputs_embeds.shape[0], -1).contiguous().view(-1, 1)).to( inputs_embeds.device ) condition_length = inputs_embeds.shape[1] new_tokens = torch.zeros( inputs_embeds.shape[0], max_new_token, self.num_vq, device=inputs_embeds.device, dtype=torch.long, ) for t in range(max_new_token): audio_bos = False # If this is the first audio token, the case is special if t == 0: audio_bos = True inputs_embeds_ = inputs_embeds position_ids = torch.tensor( list(range(text_start_pos, text_start_pos + condition_length)), dtype=torch.long, device=self.device, ).unsqueeze(0) else: # Generate the following audio tokens, it is applicable to all other cases, including second and the following calling of `generate` inputs_embeds_ = self.emb_code[0](new_tokens[:, t - 1 : t, 0]) position_ids = torch.tensor( [text_start_pos + condition_length + t - 1], # prefill the previous token dtype=torch.long, device=self.device, ).unsqueeze(0) outputs: BaseModelOutputWithPast = self.model( position_ids=position_ids, # cache_position=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds_, use_cache=True, output_attentions=False, # return_dict=True, # Add this to ensure returns dict with past_key_values ) del position_ids del inputs_embeds_ hidden_states = outputs.last_hidden_state past_key_values = outputs.past_key_values with P.cached(): logits = torch.empty( hidden_states.size(0), hidden_states.size(1), self.num_audio_tokens, self.num_vq, dtype=torch.float, device=self.device, ) for num_vq_iter in range(self.num_vq): x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) logits[..., num_vq_iter] = x del x del hidden_states logits = logits[:, -1].float() logits = logits.permute(0, 2, 1) logits = logits.reshape(-1, logits.size(2)) logits /= temperature if not audio_bos: input_ids_sliced = new_tokens[:, 0:t].permute(0, 2, 1) # get previous t new tokens logits_token = input_ids_sliced.reshape( input_ids_sliced.size(0) * input_ids_sliced.size(1), -1, ).to(self.device) del input_ids_sliced for logitsProcessors in logits_processors: logits = logitsProcessors(logits_token, logits) del logits_token if force_no_stop or t < min_new_tokens: logits[:, eos_token] = -torch.inf scores = F.softmax(logits, dim=-1) del logits idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) del scores idx_next = idx_next.view(-1, self.num_vq) finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or new_tokens[:, t] = idx_next if t == 0 and finish.any(): break del idx_next if finish.all(): break # The latest generated token is not in the range returned this time. If it is an eos token, it is not returned. If it is a normal token, it is not returned. genrated_input_ids = new_tokens[:, 0:t, :] return genrated_input_ids, past_key_values @torch.inference_mode() def interleaved_generate( self, spk_embeds: torch.Tensor, conditions: List[torch.Tensor], temperature: torch.Tensor, repetition_penalty: float, eos_token: Union[int, torch.Tensor], **kwargs, ): """ For inputs_embeds, it should be like [bs=1, seq_len, hidden_dim], its content is like: |Text BOS|Spk embeds|Text-Hidden states Interleave (if applicable)|Audio BOS| where the last position is the audio BOS token. So, the first iteration in generation directly forward the model with inputs_embeds, and the last hidden states of the last position (Audio BOS) will be decoded to get the first audio token. """ temperature = torch.tensor([temperature], dtype=torch.float, device=self.device) logits_warpers, logits_processors = gen_logits( num_code=self.config.num_audio_tokens, repetition_penalty=repetition_penalty, ) eos_token = eos_token.to(conditions[0].device) num_chunks = len(conditions) text_start_pos = 0 last_window_size = 0 past_key_values = None for idx in range(num_chunks): condition = conditions[idx].to(conditions[0].device) if self.attention_type == "sliding_recompute": recomputed_conditions = [] if ( idx >= self.window_size and (idx - self.recomputed_chunks) % (self.window_size - self.recomputed_chunks) == 0 ): for i in range(self.recomputed_chunks): recomputed_conditions.append(conditions[idx - self.recomputed_chunks + i]) recomputed_conditions.append( self.emb_code[0](generated_tokens[-self.recomputed_chunks + i][:, :, 0]) ) recomputed_conditions.append(condition) condition = torch.cat(recomputed_conditions, dim=1) text_start_pos = 0 new_tokens, old_kv = self.generate_chunk( inputs_embeds=condition, temperature=temperature, repetition_penalty=repetition_penalty, eos_token=eos_token, force_no_stop=False, max_new_token=500, past_key_values=None, logits_processors=logits_processors, text_start_pos=text_start_pos, ) else: new_tokens, old_kv = self.generate_chunk( inputs_embeds=condition, temperature=temperature, repetition_penalty=repetition_penalty, eos_token=eos_token, force_no_stop=False, max_new_token=500, past_key_values=past_key_values, logits_processors=logits_processors, text_start_pos=text_start_pos, ) else: new_tokens, old_kv = self.generate_chunk( inputs_embeds=condition, temperature=temperature, repetition_penalty=repetition_penalty, eos_token=eos_token, force_no_stop=False, max_new_token=500, past_key_values=past_key_values, logits_processors=logits_processors, text_start_pos=text_start_pos, ) past_key_values = [] if self.attention_type == "sliding_window" and idx >= 1: for layer_idx in range(len(old_kv)): past_key_values.append( ( old_kv[layer_idx][0][:, :, last_window_size:, :], old_kv[layer_idx][1][:, :, last_window_size:, :], ) ) else: past_key_values = old_kv last_window_size = condition.shape[1] + new_tokens.shape[1] text_start_pos += last_window_size if idx == 0: generated_tokens = [new_tokens] else: generated_tokens.append(new_tokens) return MiniCPMTTSGenerationOutput(new_ids=torch.cat(generated_tokens, dim=1), finished=True) class CustomRepetitionPenaltyLogitsProcessorRepeat: def __init__(self, penalty: float, max_input_ids: int, past_window: int): if not isinstance(penalty, float) or not (penalty > 0): raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") self.penalty = penalty self.max_input_ids = max_input_ids self.past_window = past_window def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.size(1) > self.past_window: input_ids = input_ids.narrow(1, -self.past_window, self.past_window) freq = F.one_hot(input_ids, scores.size(1)).sum(1) if freq.size(0) > self.max_input_ids: freq.narrow(0, self.max_input_ids, freq.size(0) - self.max_input_ids).zero_() alpha = torch.pow(self.penalty, freq) scores = scores.contiguous() inp = scores.multiply(alpha) oth = scores.divide(alpha) con = scores < 0 out = torch.where(con, inp, oth) del inp, oth, scores, con, alpha return out def gen_logits(num_code: int, top_p=0.7, top_k=20, repetition_penalty=1.0): logits_warpers = [] if top_p is not None: logits_warpers.append(TopPLogitsWarper(top_p, min_tokens_to_keep=3)) if top_k is not None: logits_warpers.append(TopKLogitsWarper(top_k, min_tokens_to_keep=3)) logits_processors = [] if repetition_penalty is not None and repetition_penalty != 1: logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)) return logits_warpers, logits_processors # Copy and modified from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, **kwargs, ): if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens else: cache_length = past_length = past_key_values[0][0].shape[2] # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: input_ids = input_ids[:, past_length:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # This clo≠clo≠clone call is needed to avoid recapturing cuda graphs with →rch.comπ≤→rch.comπ≤torch.compile's mode=reduce−overheadmode=reduce-overheadmode="reduce-overhead, as otherwise the input positionidspositionidsposition_ids would have various stride during the decoding. Here, simply using .contiguous().contiguous().contiguous() is not sufficient as in the batch size = 1 case, positionidspositionidsposition_ids is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if ∈putsembeds∈putsembedsinputs_embeds are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # The clone here is for the same reason as for positionidspositionidsposition_ids. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape device = model_inputs["inputs_embeds"].device else: batch_size, sequence_length = model_inputs["input_ids"].shape device = model_inputs["input_ids"].device dtype = self.lm_head.weight.dtype min_dtype = torch.finfo(dtype).min from transformers.models.paligemma.modeling_paligemma import ( _prepare_4d_causal_attention_mask_with_cache_position, ) attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_length(), dtype=dtype, device=device, min_dtype=min_dtype, cache_position=cache_position, batch_size=batch_size, ) model_inputs.update( { "position_ids": position_ids, # "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, } ) return model_inputs