# fairytale_generator / 10_fairy_tale_book.py
# Scraped page header (not code): "Antonio0616's picture", "Upload 2 files", commit "ce09f00 verified".
import gradio as gr
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image
import torch
# ✅ Model loading: the BLIP-2 (Flan-T5-XL) processor and model are loaded once,
# at import time, and shared by every request.
_BLIP2_CHECKPOINT = "Salesforce/blip2-flan-t5-xl"
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
_target_device = "cuda" if torch.cuda.is_available() else "cpu"
processor = Blip2Processor.from_pretrained(_BLIP2_CHECKPOINT)
model = Blip2ForConditionalGeneration.from_pretrained(_BLIP2_CHECKPOINT).to(_target_device)
# 📌 Step 1: extract an image description (image → caption)
def extract_caption(image):
    """Caption *image* with the module-level BLIP-2 model.

    Args:
        image: the picture to describe. The Gradio input in this file uses
            ``type="pil"``, so this is normally a PIL image.

    Returns:
        The decoded caption (at most 50 newly generated tokens), with
        special tokens removed.
    """
    encoded = processor(images=image, return_tensors="pt")
    # Inputs must live on the same device as the model weights.
    encoded = encoded.to(model.device)
    generated_ids = model.generate(**encoded, max_new_tokens=50)
    first_sequence = generated_ids[0]
    return processor.tokenizer.decode(first_sequence, skip_special_tokens=True)
# 📌 Step 2: build the fairytale prompt from the caption
def build_prompt_from_caption(caption):
    """Wrap *caption* in an instruction asking for a children's fairytale.

    Args:
        caption: text describing the image; it is embedded inside double
            quotes in the returned instruction.

    Returns:
        A single prompt string: the quoted caption followed by fixed
        writing rules (opening phrase, minimum length, required elements).
    """
    opening = "Write a magical and fun children's fairytale based on this description: "
    quoted_caption = f'"{caption}". '
    rules = (
        "Start with 'Once upon a time' and continue for at least 7 sentences. "
        "Include characters, emotions, the setting, and a twist. Make it feel like a real story."
    )
    return opening + quoted_caption + rules
# 📌 Step 3: generate the fairytale (caption + prompt → text)
def generate_fairytale(image):
    """Turn an uploaded image into a children's fairytale.

    Pipeline: caption the image with ``extract_caption``, wrap the caption in
    a story-writing instruction with ``build_prompt_from_caption``, then feed
    the image together with that instruction back to the BLIP-2 model, which
    samples the story text.

    Args:
        image: the picture to turn into a story (a PIL image from the
            ``gr.Image(type="pil")`` input), or ``None`` when the button is
            clicked before anything has been uploaded.

    Returns:
        The generated story text with surrounding whitespace removed, or a
        short notice asking the user to upload an image when ``image`` is
        ``None``.
    """
    # Gradio passes None when no image has been uploaded. The captioning step
    # needs pixels, so show a friendly notice in the output box instead of
    # surfacing an exception to the user.
    if image is None:
        return "⚠️ 먼저 이미지를 업로드해 주세요."

    caption = extract_caption(image)
    prompt = build_prompt_from_caption(caption)
    # Condition generation on both the image and the instruction text.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=400,
        # Sampling (rather than greedy decoding) yields a different story on
        # each click for the same picture.
        do_sample=True,
        temperature=0.95,
        top_p=0.9,
    )
    story = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
    return story.strip()
# 🌐 Gradio UI: one image input, one button, one text output, each in its own row.
with gr.Blocks() as demo:
    # Header (Korean): "Image-based AI fairytale generator — upload a photo and we turn it into a fairytale!"
    gr.Markdown("## 🧚‍♂️ 이미지 기반 AI 동화 생성기\n사진을 업로드하면 동화로 바꿔드립니다!")

    with gr.Row():
        # type="pil" makes Gradio hand the handler a PIL image.
        image_input = gr.Image(label="📸 이미지 업로드", type="pil")
    with gr.Row():
        generate_button = gr.Button("✨ 동화 만들기")
    with gr.Row():
        output_text = gr.Textbox(lines=10, label="📖 생성된 동화")

    # A click runs the image → caption → story pipeline and shows the result.
    generate_button.click(
        generate_fairytale,
        inputs=[image_input],
        outputs=[output_text],
    )
# Run: start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link for the app.
    demo.launch(share=True)