---
license: mit
pipeline_tag: video-text-to-text
library_name: transformers
---

# M4-Audio-LongVA-7B-Qwen2

Enhancing Omni Interactive Capabilities in MLLM

This repository contains the model described in [OmniMMI: A Comprehensive Multi-modal Interaction Benchmark in Streaming Video Contexts](https://huggingface.co/papers/2503.22952).
The code can be found at https://github.com/patrick-tssn/M4.

![images](./assets/framework.png)

M4-Audio-7B is an extension of [LongVA-7B](https://github.com/EvolvingLMMs-Lab/LongVA), further trained using the [M4-IT](https://huggingface.co/datasets/ColorfulAI/M4-IT) dataset, which comprises 9,963 visual-audio instruction tuning instances. This training was conducted without any special modifications to the existing training pipeline.
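
For reference, the M4-IT instruction data can be inspected directly from the Hub. The snippet below is a minimal sketch that assumes the dataset loads with the standard `datasets` API; the split and field names are not guaranteed to match this guess.

```python
# Minimal sketch (assumption): browse the M4-IT instruction-tuning data with the
# standard `datasets` API; split and field names may differ on the actual dataset.
from datasets import load_dataset

m4_it = load_dataset("ColorfulAI/M4-IT")
print(m4_it)                      # available splits and their sizes
first_split = next(iter(m4_it))
print(m4_it[first_split][0])      # one raw instruction-tuning record
```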


## Usage


*Please refer to [M4](https://github.com/patrick-tssn/M4) to install the relevant packages.*

```python
import os
from PIL import Image
import numpy as np
import torchaudio
import torch
from decord import VideoReader, cpu
import whisper
# fix seed
torch.manual_seed(0)

from intersuit.model.builder import load_pretrained_model
from intersuit.mm_utils import tokenizer_image_speech_tokens, process_images
from intersuit.constants import IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX

import ChatTTS
chat = ChatTTS.Chat()
chat.load(source='local', compile=True)

import warnings
warnings.filterwarnings("ignore")

model_path = "checkpoints/M4-Audio-LongVA-7B-Qwen2"
video_path = "local_demo/assets/water.mp4"
audio_path = "local_demo/wav/infer.wav"
new_audio_path = "local_demo/wav/new_infer.wav"
max_frames_num = 16 # you can increase this to several thousand frames, as long as your GPU memory can handle it :)
gen_kwargs = {"do_sample": True, "temperature": 0.5, "top_p": None, "num_beams": 1, "use_cache": True, "max_new_tokens": 1024}
tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None, "llava_qwen", device_map="cuda:0", attn_implementation="eager")

# original query
query = "Give a detailed caption of the video as if I am blind."
query = None # comment this to use ChatTTS to convert the query to audio
prompt = "<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<image><|im_end|>
<|im_start|>user
<speech>
<|im_end|>
<|im_start|>assistant
"
input_ids = tokenizer_image_speech_tokens(prompt, tokenizer, IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
pad_token_ids = (tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id)
attention_masks = input_ids.ne(pad_token_ids).to(input_ids.device)
# audio input
if query is not None:
    # synthesize the query into speech with ChatTTS, overwriting any previous file
    audio_path = "./local_demo/wav/infer.wav"
    if os.path.exists(audio_path):
        os.remove(audio_path)  # refresh
    wav = chat.infer(query)
    try:
        torchaudio.save(audio_path, torch.from_numpy(wav).unsqueeze(0), 24000)
    except Exception:
        torchaudio.save(audio_path, torch.from_numpy(wav), 24000)
speech = whisper.load_audio(audio_path)
speech = whisper.pad_or_trim(speech)
speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0).to(device=model.device, dtype=torch.float16)
speech_length = torch.LongTensor([speech.shape[0]]).to(model.device)

# new (interrupting) query: alternative examples, the last assignment takes effect
new_query = "How many people are in the video?"
new_query = "Okay, I see."
new_query = "Sorry to interrupt."
new_query_pos = 10 # decoding step at which the new query is injected
new_query = None # comment out this line to have ChatTTS synthesize the interrupting query above into audio; keep it to use the prerecorded wav
new_prompt = "<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
<speech>
<|im_end|>
<|im_start|>assistant
"
new_input_ids = tokenizer_image_speech_tokens(new_prompt, tokenizer, IMAGE_TOKEN_INDEX, SPEECH_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)
# audio input
if new_query is not None:
    # synthesize the interrupting query into speech with ChatTTS, overwriting any previous file
    new_audio_path = "./local_demo/wav/new_infer.wav"
    if os.path.exists(new_audio_path):
        os.remove(new_audio_path)  # refresh
    wav = chat.infer(new_query)
    try:
        torchaudio.save(new_audio_path, torch.from_numpy(wav).unsqueeze(0), 24000)
    except Exception:
        torchaudio.save(new_audio_path, torch.from_numpy(wav), 24000)
new_speech = whisper.load_audio(new_audio_path)
new_speech = whisper.pad_or_trim(new_speech)
new_speech = whisper.log_mel_spectrogram(new_speech, n_mels=128).permute(1, 0).to(device=model.device, dtype=torch.float16)
new_speech_length = torch.LongTensor([new_speech.shape[0]]).to(model.device)

# video input
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
video_tensor = image_processor.preprocess(frames, return_tensors="pt")["pixel_values"].to(model.device, dtype=torch.bfloat16)


with torch.inference_mode():
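    # interruption-aware parallel decoding: the model begins answering the first
    # (audio) query and, after `new_query_pos` generated tokens, the second query
    # is injected, so the model must handle input that arrives mid-response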
    output_ids = model.generate_parallel(input_ids, 
                                attention_mask=attention_masks,
                                images=[video_tensor], 
                                modalities=["video"], 
                                speeches=speech.unsqueeze(0), 
                                speech_lengths=speech_length,
                                new_query=new_input_ids,
                                new_query_pos=new_query_pos,
                                new_speeches=new_speech.unsqueeze(0),
                                new_speech_lengths=new_speech_length,
                                query_str=query,
                                new_query_str=new_query,
                                tokenizer=tokenizer,
                                **gen_kwargs)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

```
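
Optionally, the decoded reply can be synthesized back to speech with the same ChatTTS instance loaded above, closing the spoken interaction loop. This is a minimal sketch and not part of the original pipeline; the output path is arbitrary.

```python
# Optional sketch (not part of the original pipeline): speak the model's text reply
# back to the user with the ChatTTS instance loaded above.
print(outputs)
reply_wav = chat.infer(outputs)
reply_path = "local_demo/wav/reply.wav"  # arbitrary output location
try:
    torchaudio.save(reply_path, torch.from_numpy(reply_wav).unsqueeze(0), 24000)
except Exception:
    torchaudio.save(reply_path, torch.from_numpy(reply_wav), 24000)
```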


For more information about the interaction inference pipeline, please visit the [M4 GitHub repository](https://github.com/patrick-tssn/M4).