Spaces:
Runtime error
Runtime error
File size: 5,845 Bytes
2c99f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# import gradio as gr
# from transformers import (
# Blip2Processor,
# Blip2ForConditionalGeneration,
# AutoTokenizer,
# AutoModelForCausalLM,
# )
# from PIL import Image
# import torch
# # Set device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # Load image captioning model (BLIP-2)
# processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
# blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
# # Load text generation model (LLM)
# llm_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# llm_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2").to(device)
# # Step 1: Generate image caption
# def extract_caption(image):
# inputs = processor(images=image, return_tensors="pt").to(device)
# outputs = blip_model.generate(**inputs, max_new_tokens=50)
# caption = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
# return caption
# # Step 2: Build fairytale prompt
# def build_prompt(caption):
# return (
# f"Based on the image description: \"{caption}\", write a children's fairytale.\n"
# "The story must:\n"
# "- Start with 'Once upon a time'\n"
# "- Be at least 10 full sentences long\n"
# "- Include named characters, a clear setting, emotions, a challenge, and a resolution\n"
# "- Avoid mentions of babies or unrelated royalty unless relevant\n"
# "Here is the story:\nOnce upon a time"
# )
# # Step 3: Generate story
# def generate_fairytale(image):
# caption = extract_caption(image)
# prompt = build_prompt(caption)
# inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
# output = llm_model.generate(
# **inputs,
# max_new_tokens=500,
# do_sample=True,
# temperature=0.9,
# top_p=0.95,
# pad_token_id=llm_tokenizer.eos_token_id
# )
# result = llm_tokenizer.decode(output[0], skip_special_tokens=True)
# # Trim to only the story
# if "Once upon a time" in result:
# return "Once upon a time" + result.split("Once upon a time", 1)[-1].strip()
# else:
# return f"⚠️ Failed to generate story.\n\n[Prompt]\n{prompt}\n\n[Output]\n{result}"
# # Gradio interface
# with gr.Blocks() as demo:
# gr.Markdown("## 📖 AI Fairytale Generator\nUpload an image and get a magical story!")
# with gr.Row():
# image_input = gr.Image(type="pil", label="Upload an image")
# with gr.Row():
# generate_button = gr.Button("✨ Generate Fairytale")
# with gr.Row():
# output_text = gr.Textbox(label="Generated Story", lines=20)
# generate_button.click(fn=generate_fairytale, inputs=[image_input], outputs=[output_text])
# if __name__ == "__main__":
# demo.launch(share=True)
import gradio as gr
from transformers import (
Blip2Processor,
Blip2ForConditionalGeneration,
AutoTokenizer,
AutoModelForCausalLM,
)
from PIL import Image
import torch
# Device selection: everything runs on the CPU (the Mistral model is deliberately forced onto the CPU).
device = "cpu"

# 1) Image-captioning model (BLIP-2, Flan-T5-XL variant).
_BLIP_CHECKPOINT = "Salesforce/blip2-flan-t5-xl"
processor = Blip2Processor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = Blip2ForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(device)

# 2) Fairytale-generation model (Mistral-7B-Instruct v0.2), placed on the CPU via device_map.
# NOTE(review): passing device_map makes transformers rely on the `accelerate` package;
# confirm it is installed in the runtime environment.
# NOTE(review): neither model is given a torch_dtype, so both load in their default
# precision; together this is a very large CPU memory footprint — confirm the host has
# enough RAM (a likely cause of the Space's "Runtime error").
_LLM_CHECKPOINT = "mistralai/Mistral-7B-Instruct-v0.2"
llm_tokenizer = AutoTokenizer.from_pretrained(_LLM_CHECKPOINT)
llm_model = AutoModelForCausalLM.from_pretrained(_LLM_CHECKPOINT, device_map=device)
# 3) Image -> text description.
def extract_caption(image):
    """Describe *image* with the BLIP-2 captioning model.

    Args:
        image: The picture to caption. The Gradio input in this app is
            configured with ``type="pil"``, so this is presumably a PIL image
            — confirm for any other caller.

    Returns:
        The caption decoded from the first generated sequence (at most 50 new
        tokens), with special tokens skipped and surrounding whitespace removed.
    """
    model_inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = blip_model.generate(**model_inputs, max_new_tokens=50)
    first_sequence = generated_ids[0]
    text = processor.tokenizer.decode(first_sequence, skip_special_tokens=True)
    return text.strip()
# 4) Prompt construction.
def build_prompt(caption):
    """Build the instruction prompt asking the LLM for a fairytale about *caption*.

    The prompt states the story requirements and ends with the words
    "Once upon a time", so the model continues the story from there.

    Args:
        caption: Image description, interpolated in double quotes into the
            first line.

    Returns:
        The complete prompt as a single newline-separated string.
    """
    requirements = (
        "- Start with 'Once upon a time'",
        "- Be at least 10 full sentences long",
        "- Include named characters, a clear setting, emotions, a challenge, and a resolution",
        "- Avoid unrelated royalty or babies unless relevant",
    )
    prompt_lines = [
        f"Based on the image description: \"{caption}\", write a children's fairytale.",
        "The story must:",
        *requirements,
        "Here is the story:",
        "Once upon a time",
    ]
    return "\n".join(prompt_lines)
# 5) Full pipeline: image -> caption -> prompt -> fairytale.
def _generated_continuation(prompt_text, full_text):
    """Return the text the model produced after the echoed prompt, or None.

    For a decoder-only model, ``generate`` returns the prompt tokens followed by
    the new tokens, so the decoded output begins with the decoded prompt.
    Dropping that many characters (the approach the transformers
    text-generation pipeline uses) keeps the spacing of the first generated
    word intact. Returns None when the decoded output does not start with the
    decoded prompt (tokenizer round-trip mismatch), so the caller can report it.
    """
    if not full_text.startswith(prompt_text):
        return None
    return full_text[len(prompt_text):]


def generate_fairytale(image):
    """Turn an uploaded image into a caption and a children's fairytale.

    Args:
        image: Image from the Gradio input, or ``None`` when the button is
            pressed without an upload.

    Returns:
        A ``(caption, story)`` tuple of strings for the two output textboxes.
        When no story could be extracted, the story slot holds a warning that
        includes the prompt and the raw model output for debugging.
    """
    if image is None:
        # Nothing to caption; presumably the BLIP processor would fail on None.
        return "", "⚠️ Please upload an image first."

    caption = extract_caption(image)
    prompt = build_prompt(caption)
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    output = llm_model.generate(
        **inputs,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        pad_token_id=llm_tokenizer.eos_token_id,
    )
    result = llm_tokenizer.decode(output[0], skip_special_tokens=True)

    # The prompt itself contains "Once upon a time" (in its rules and at its
    # very end), so splitting the decoded text on that phrase would cut inside
    # the echoed instructions. Strip the echoed prompt instead and keep only
    # the newly generated continuation; don't lstrip it, or the space after
    # "Once upon a time" is lost.
    prompt_text = llm_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
    continuation = _generated_continuation(prompt_text, result)
    if continuation is not None and continuation.strip():
        return caption, "Once upon a time" + continuation.rstrip()
    return caption, f"⚠️ Story generation failed.\n\n[Prompt]\n{prompt}\n\n[Output]\n{result}"
# 6) Gradio user interface.
with gr.Blocks() as demo:
    # Header: a title line followed by a short usage hint.
    gr.Markdown("## 🧚 AI Fairytale Generator (Mistral CPU ver.)")
    gr.Markdown("Upload an image and receive a magical children's fairytale based on it ✨")

    # Components render in creation order: image input, button, then the two outputs.
    image_input = gr.Image(type="pil", label="🖼️ Upload an Image")
    generate_button = gr.Button("✨ Generate Fairytale")
    caption_output = gr.Textbox(label="📌 Image Description", lines=2)
    story_output = gr.Textbox(label="📖 Generated Fairytale", lines=20)

    # A single click runs the whole pipeline; its (caption, story) result
    # fills the two textboxes in that order.
    generate_button.click(
        fn=generate_fairytale,
        inputs=[image_input],
        outputs=[caption_output, story_output],
    )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run locally.
    # NOTE(review): on Hugging Face Spaces this flag is presumably ignored — confirm.
    demo.launch(share=True)
|