# wildfire_teller/app.py
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import gradio as gr
from PIL import Image
from huggingface_hub import login
import os
import warnings
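
# Note: qwen_vl_utils (imported above) ships as the "qwen-vl-utils" package on
# PyPI; it provides process_vision_info for extracting image/video inputs from
# Qwen2-VL-style chat messages.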
# Suppress runtime warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# ========== Log in using your secret named "fmv" ==========
token = os.getenv("fmv")
if token:
    login(token=token)
    print("Successfully logged in with token!")
else:
    print("Warning: Token not found")
# ===========================================================
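
# On Hugging Face Spaces, repository secrets are exposed to the app as
# environment variables, which is why os.getenv("fmv") works above.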
# Hugging Face model repository path
model_path = "hiko1999/Qwen2-Wildfire-2B"
# Load model and processor
print(f"Loading model: {model_path}")
tokenizer = AutoTokenizer.from_pretrained(model_path)  # note: unused below; the processor wraps its own tokenizer
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
)
processor = AutoProcessor.from_pretrained(model_path)
print("Model loaded successfully!")
# Define prediction function
def predict(image):
    """Process image and generate description"""
    if image is None:
        return "Error: No image uploaded"
    try:
        # Build message with English prompt
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Describe this wildfire scene in English. Include details about the fire intensity, affected area, and visible environmental conditions."}
                ]
            }
        ]
        # Process input
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt"
        )
        # Ensure running on CPU
        inputs = inputs.to("cpu")
        # Generate output
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7
        )
        # Decode output, trimming the prompt tokens from each sequence
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )
        return output_text[0]
    except Exception as e:
        return f"Prediction failed: {str(e)}"
# Gradio interface function
def gradio_interface(image):
    """Main function for Gradio interface"""
    result = predict(image)
    return result
# Create Gradio interface (all in English)
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil", label="Upload Wildfire Image"),
    outputs=gr.Textbox(label="AI Analysis Result", lines=10),
    title="🔥 Wildfire Scene Analysis System",
    description="Upload a wildfire-related image and AI will automatically analyze and describe the fire situation in English."
)
# Launch interface
if __name__ == "__main__":
    interface.launch(share=False)
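
# Note: share=False serves the demo only at the local/Space URL; on a machine
# outside Spaces, launching with share=True would ask Gradio to create a
# temporary public *.gradio.live link for quick sharing.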