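# NOTE: the lines below are header text from the hosting page, not program code;
# they are kept as comments so that the file parses as Python.
# fairytale_generator / blip-image-captioning-base02.py
# gamin
# base02 파일 추가 (캡션, 프롬프트 변경)  (English: "Add base02 file (caption and prompt changes)")
# e8dd7e6
# raw
# history blame
# 4.6 kB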
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from PIL import Image
import gradio as gr
import torch
# 1. Load the image-captioning model (BLIP).
# The processor and the model are loaded from the same checkpoint, so the
# checkpoint name is spelled out once.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
caption_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

# 2. Load the fairytale-generation model (Flan-T5) and wrap the tokenizer and
# model in a text2text-generation pipeline.
FLAN_T5_CHECKPOINT = "google/flan-t5-base"
story_tokenizer = AutoTokenizer.from_pretrained(FLAN_T5_CHECKPOINT)
story_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_T5_CHECKPOINT)
story_generator = pipeline(
    task="text2text-generation",
    model=story_model,
    tokenizer=story_tokenizer,
)
# 3. Fairytale generation function
def generate_fairytale(image):
    """Caption an image and generate a children's fairytale from the caption.

    Args:
        image: The picture to describe, as handed over by the image input of
            the UI, or ``None`` when nothing was uploaded.

    Returns:
        A ``(caption, story)`` tuple of strings. When ``image`` is ``None``
        the caption is an empty string and the story slot holds a short
        message asking for an upload.
    """
    # A submission without an upload arrives as None, which the captioning
    # processor cannot handle; answer with a hint instead of raising.
    if image is None:
        return "", "Please upload an image first."

    # (1) Image -> caption
    inputs = caption_processor(images=image, return_tensors="pt").to(caption_model.device)
    outputs = caption_model.generate(**inputs)
    caption = caption_processor.decode(outputs[0], skip_special_tokens=True).strip()

    # (2) Pad out captions shorter than five words so the story model gets
    # more to work with.
    if len(caption.split()) < 5:
        caption += ". They look like magical characters from a fantasy story."

    # (3) Build the prompt. The text starts with a blank line and ends with a
    # newline after "Story:".
    prompt = (
        "\n"
        "Write a magical and imaginative fairytale for children based on the following image description.\n"
        f'Description: "{caption}"\n'
        "Your story should:\n"
        "- Be at least 3 paragraphs long\n"
        '- Start with "Once upon a time"\n'
        "- Include fantasy, adventure, or mystery elements\n"
        "- Be creative and heartwarming\n"
        "Story:\n"
    )

    # (4) Generate the story (sampling enabled, generous length limit).
    story = story_generator(prompt, max_length=500, do_sample=True)[0]['generated_text']
    return caption, story
# 4. Gradio interface
# Components are created in the same order as before: image input first,
# then the two text outputs.
image_input = gr.Image(type="pil", label="🖼️ Upload an Image")
caption_output = gr.Textbox(label="📌 Image Description")
story_output = gr.Textbox(label="📖 Generated Fairytale")

interface = gr.Interface(
    fn=generate_fairytale,
    inputs=image_input,
    outputs=[caption_output, story_output],
    title="🌟 AI Fairytale Generator from Image",
    description="Upload an image and get a rich fairytale story created from it!",
    theme="soft",
)

# 5. Launch the app
launch_options = {"share": True, "debug": True, "inbrowser": True}
interface.launch(**launch_options)
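# --- Commented-out alternative version of the script above (deterministic beam search instead of sampling); not executed ---
# from transformers import BlipProcessor, BlipForConditionalGeneration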
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# from PIL import Image
# import gradio as gr
# import torch
# # 1. 이미지 설명 생성 모델 (BLIP)
# caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# # 2. 영어 동화 생성 모델 (FLAN-T5)
# story_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
# story_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
# # 3. 동화 생성 함수
# def generate_fairytale(image):
# # 1. 이미지 → 텍스트 설명 생성
# inputs = caption_processor(images=image, return_tensors="pt").to(caption_model.device)
# outputs = caption_model.generate(**inputs)
# caption = caption_processor.decode(outputs[0], skip_special_tokens=True).strip()
# # 2. 프롬프트 구성 (중복 방지 조건 추가)
# prompt = f"""Write a magical and imaginative children's story based on the following image description.
# Description: "{caption}"
# The story should be at least 3 paragraphs and must not repeat any sentences.
# Story:"""
# # 3. 텍스트 생성
# input_ids = story_tokenizer(prompt, return_tensors="pt").input_ids.to(story_model.device)
# output_ids = story_model.generate(
# input_ids,
# max_length=600, # 더 길고 풍부한 텍스트
# num_beams=4, # 빔 탐색
# no_repeat_ngram_size=3, # 반복 방지
# repetition_penalty=1.3, # 반복 패널티
# early_stopping=True,
# do_sample=False # 확정적 결과
# )
# story = story_tokenizer.decode(output_ids[0], skip_special_tokens=True)
# return caption, story
# # 4. Gradio UI 인터페이스
# interface = gr.Interface(
# fn=generate_fairytale,
# inputs=gr.Image(type="pil", label="📷 Upload an Image"),
# outputs=[
# gr.Textbox(label="📌 Image Description"),
# gr.Textbox(label="📖 Generated Fairytale")
# ],
# title="🌟 AI Fairytale Generator from Image",
# description="Upload an image and get a rich fairytale story created from it!",
# theme="default"
# )
# # 5. 실행
# interface.launch(share=True)