hiko1999 committed on
Commit
22f3ff8
·
verified ·
1 Parent(s): 70ae9d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -31
app.py CHANGED
@@ -3,48 +3,109 @@ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoPro
3
  from qwen_vl_utils import process_vision_info
4
  import gradio as gr
5
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Hugging Face 模型仓库路径
8
- model_path = "hiko1999/Qwen2-Wildfire-2B" # 替换为你的模型路径
9
 
10
- # 加载 Hugging Face 上的模型和 processor
 
11
  tokenizer = AutoTokenizer.from_pretrained(model_path)
12
- model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16) # 移除 device_map 参数以避免自动分配到 GPU
 
 
 
 
13
  processor = AutoProcessor.from_pretrained(model_path)
 
14
 
15
  # 定义预测函数
16
  def predict(image):
17
- # 将上传的图片处理为模型需要的格式
18
- messages = [{"role": "user",
19
- "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}]}]
20
-
21
- # 处理图片输入
22
- text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
23
- image_inputs, video_inputs = process_vision_info(messages)
24
- inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
25
-
26
- # 将数据转移到 CPU
27
- inputs = inputs.to("cpu") # 使用 CPU 而不是 CUDA
28
-
29
- # 生成模型输出
30
- generated_ids = model.generate(**inputs, max_new_tokens=128)
31
- generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
32
- output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True,
33
- clean_up_tokenization_spaces=False)
34
-
35
- return output_text[0] # 返回生成的文本
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # Gradio界面
38
  def gradio_interface(image):
 
39
  result = predict(image)
40
- return f"预测结果:{result}"
41
 
42
- # 创建Gradio接口
43
- interface = gr.Interface(fn=gradio_interface,
44
- inputs=gr.Image(type="pil"), # 输入的图像
45
- outputs="text", # 输出结果
46
- title="火灾场景多模态模型预测",
47
- description="上传图片进行火灾预测。")
 
 
 
 
48
 
49
  # 启动接口
50
- interface.launch()
 
 
3
  from qwen_vl_utils import process_vision_info
4
  import gradio as gr
5
  from PIL import Image
6
+ from huggingface_hub import login
7
+ import os
8
+
# ---------- Log in to the Hugging Face Hub using the secret named "fmv" ----------
token = os.environ.get("fmv")  # value of the "fmv" secret, or None when it is not set
if not token:
    # No login is attempted without a token; only a warning is printed.
    print("警告:未找到 token")
else:
    login(token=token)
    print("成功使用 token 登录!")
# ----------------------------------------------------------------------------------
17
 
# Hugging Face model repository path
model_path = "hiko1999/Qwen2-Wildfire-2B"

# Load the tokenizer, the model and the processor from the same repository.
print(f"正在加载模型: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    device_map="cpu",  # keep every weight on the CPU
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(model_path)
print("模型加载完成!")
31
 
# Prediction function
def predict(image):
    """Generate a text description of the fire situation shown in an image.

    Args:
        image: The uploaded picture, or ``None`` when nothing was uploaded.
            It is passed unchanged as the ``"image"`` entry of the chat
            message.

    Returns:
        The decoded model output for the single submitted image, or a
        human-readable error string when no image was supplied or when any
        preprocessing/generation step raised an exception.
    """
    if image is None:
        return "错误:未上传图片"

    try:
        # Single-turn chat message: one image plus a fixed text instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "请描述这张图片中的火灾情况。"},
                ],
            }
        ]

        # Render the chat template to a prompt string and collect the
        # vision inputs referenced by the message.
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Follow the device the model was loaded on instead of hard-coding
        # "cpu", so the tensors stay on the same device as the weights.
        inputs = inputs.to(model.device)

        # Sampled generation, capped at 256 new tokens.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
        )

        # Drop the prompt tokens so that only the newly generated part of
        # each sequence is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        return output_text[0]

    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the
        # request, but keep the full traceback in the server log so the
        # failure can actually be diagnosed.
        import traceback

        traceback.print_exc()
        return f"预测失败: {e}"
91
 
# Gradio callback
def gradio_interface(image):
    """Callback for the Gradio UI: return the text produced by predict() unchanged."""
    return predict(image)
97
 
# Build the Gradio interface: one PIL image in, one multi-line text box out.
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil", label="上传火灾图片"),
    outputs=gr.Textbox(label="AI 分析结果", lines=10),
    title="🔥 火灾场景智能分析系统",
    description="上传火灾相关图片,AI 将自动分析并描述火灾情况。",
    theme=gr.themes.Soft(),
    allow_flagging="never",
)
108
 
# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()