{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b92d046f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['VLLM_USE_V1'] = '0'\n",
    "os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'\n",
    "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"0\"\n",
    "import torch\n",
    "import warnings\n",
    "import numpy as np\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "warnings.filterwarnings('ignore', category=DeprecationWarning)\n",
    "warnings.filterwarnings('ignore', category=FutureWarning)\n",
    "warnings.filterwarnings('ignore', category=UserWarning)\n",
    "\n",
    "from qwen_omni_utils import process_mm_info\n",
    "from transformers import Qwen3OmniMoeProcessor\n",
    "\n",
    "def _load_model_processor():\n",
    "    if USE_TRANSFORMERS:\n",
    "        from transformers import Qwen3OmniMoeForConditionalGeneration\n",
    "        if TRANSFORMERS_USE_FLASH_ATTN2:\n",
    "            model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(MODEL_PATH,\n",
    "                                                                         dtype='auto',\n",
    "                                                                         attn_implementation='flash_attention_2',\n",
    "                                                                         device_map=\"auto\")\n",
    "        else:\n",
    "            model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(MODEL_PATH, device_map=\"auto\", dtype='auto')\n",
    "    else:\n",
    "        from vllm import LLM\n",
    "        model = LLM(\n",
    "            model=MODEL_PATH, trust_remote_code=True, gpu_memory_utilization=0.95,\n",
    "            tensor_parallel_size=torch.cuda.device_count(),\n",
    "            limit_mm_per_prompt={'image': 1, 'video': 3, 'audio': 3},\n",
    "            max_num_seqs=1,\n",
    "            max_model_len=8192,\n",
    "            seed=1234,\n",
    "        )\n",
    "\n",
    "    processor = Qwen3OmniMoeProcessor.from_pretrained(MODEL_PATH)\n",
    "    return model, processor\n",
    "\n",
    "def run_model(model, processor, messages, return_audio, use_audio_in_video):\n",
    "    if USE_TRANSFORMERS:\n",
    "        text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)\n",
    "        audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video)\n",
    "        inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors=\"pt\", padding=True, use_audio_in_video=use_audio_in_video)\n",
    "        inputs = inputs.to(model.device).to(model.dtype)\n",
    "        text_ids, audio = model.generate(**inputs,\n",
    "                                            thinker_return_dict_in_generate=True,\n",
    "                                            thinker_max_new_tokens=8192,\n",
    "                                            thinker_do_sample=True,\n",
    "                                            thinker_top_p=0.95,\n",
    "                                            thinker_top_k=20,\n",
    "                                            thinker_temperature=0.6,\n",
    "                                            speaker=\"Chelsie\",\n",
    "                                            use_audio_in_video=use_audio_in_video,\n",
    "                                            return_audio=return_audio)\n",
    "        response = processor.batch_decode(text_ids.sequences[:, inputs[\"input_ids\"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]\n",
    "        if audio is not None:\n",
    "            audio = np.array(audio.reshape(-1).detach().cpu().numpy() * 32767).astype(np.int16)\n",
    "        return response, audio\n",
    "    else:\n",
    "        from vllm import SamplingParams\n",
    "        sampling_params = SamplingParams(temperature=0.6, top_p=0.95, top_k=20, max_tokens=4096)\n",
    "        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
    "        audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video)\n",
    "        inputs = {'prompt': text, 'multi_modal_data': {}, \"mm_processor_kwargs\": {\"use_audio_in_video\": use_audio_in_video}}\n",
    "        if images is not None: inputs['multi_modal_data']['image'] = images\n",
    "        if videos is not None: inputs['multi_modal_data']['video'] = videos\n",
    "        if audios is not None: inputs['multi_modal_data']['audio'] = audios\n",
    "        outputs = model.generate(inputs, sampling_params=sampling_params)\n",
    "        response = outputs[0].outputs[0].text\n",
    "        return response, None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37dcedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import librosa\n",
    "import audioread\n",
    "\n",
    "from IPython.display import Audio\n",
    "\n",
    "MODEL_PATH = \"NandemoGHS/Anime-Speech-Japanese-Captioner-FP8-DYNAMIC\"\n",
    "\n",
    "USE_TRANSFORMERS = False\n",
    "TRANSFORMERS_USE_FLASH_ATTN2 = True\n",
    "\n",
    "model, processor = _load_model_processor()\n",
    "\n",
    "USE_AUDIO_IN_VIDEO = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bf60bf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_path = \"https://huggingface.co/NandemoGHS/Anime-Speech-Japanese-Captioner/resolve/main/examples/example1.wav\"\n",
    "\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": [\n",
    "            {\"type\": \"audio\", \"audio\": audio_path}\n",
    "        ]\n",
    "    }\n",
    "]\n",
    "\n",
    "display(Audio(librosa.load(audioread.ffdec.FFmpegAudioFile(audio_path), sr=16000)[0], rate=16000))\n",
    "\n",
    "response, _ = run_model(model=model, messages=messages, processor=processor, return_audio=False, use_audio_in_video=USE_AUDIO_IN_VIDEO)\n",
    "\n",
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv (3.10.12)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}