# File size: 4,598 Bytes
# e8dd7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from PIL import Image
import gradio as gr
import torch

# 1. Load the image-captioning model (BLIP). The processor and the model are
#    both loaded from the same checkpoint id.
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
caption_processor = BlipProcessor.from_pretrained(_CAPTION_CHECKPOINT)
caption_model = BlipForConditionalGeneration.from_pretrained(_CAPTION_CHECKPOINT)

# 2. Load the fairytale-generation model (Flan-T5) and wrap the tokenizer and
#    model in a text2text-generation pipeline.
_STORY_CHECKPOINT = "google/flan-t5-base"
story_tokenizer = AutoTokenizer.from_pretrained(_STORY_CHECKPOINT)
story_model = AutoModelForSeq2SeqLM.from_pretrained(_STORY_CHECKPOINT)
story_generator = pipeline(
    "text2text-generation",
    model=story_model,
    tokenizer=story_tokenizer,
)

# 3. Fairytale generation function
def generate_fairytale(image):
    """Caption an image with BLIP and write a children's fairytale from the caption.

    Args:
        image: The image to describe. The Gradio input wired to this function
            is declared with ``type="pil"``. ``None`` is accepted and means
            "nothing was uploaded".

    Returns:
        A ``(caption, story)`` tuple of strings. ``caption`` is the decoded
        BLIP caption, padded with an extra sentence when it has fewer than
        five words. ``story`` is the ``generated_text`` of the first result
        returned by ``story_generator``. When ``image`` is ``None`` the
        caption is empty and ``story`` carries a short hint for the user.
    """
    # Gradio hands the function None when the form is submitted without an
    # image; return a readable hint instead of letting the processor raise.
    if image is None:
        return "", "Please upload an image first."

    # (1) Image -> caption
    inputs = caption_processor(images=image, return_tensors="pt").to(caption_model.device)
    outputs = caption_model.generate(**inputs)
    caption = caption_processor.decode(outputs[0], skip_special_tokens=True).strip()

    # (2) Pad very short captions so the story model has more to work with
    if len(caption.split()) < 5:
        caption += ". They look like magical characters from a fantasy story."

    # (3) Build the prompt (text is sent to the model verbatim)
    prompt = f"""
Write a magical and imaginative fairytale for children based on the following image description.

Description: "{caption}"

Your story should:
- Be at least 3 paragraphs long
- Start with "Once upon a time"
- Include fantasy, adventure, or mystery elements
- Be creative and heartwarming

Story:
"""

    # (4) Generate the story; a generous max_length leaves room for a longer
    #     text, and do_sample=True requests sampling instead of greedy decoding
    story = story_generator(prompt, max_length=500, do_sample=True)[0]['generated_text']

    return caption, story

# 4. Gradio interface: one uploaded image in, two text boxes out
#    (the image description and the generated fairytale).
_image_input = gr.Image(type="pil", label="🖼️ Upload an Image")
_text_outputs = [
    gr.Textbox(label="📌 Image Description"),
    gr.Textbox(label="📖 Generated Fairytale"),
]
interface = gr.Interface(
    fn=generate_fairytale,
    inputs=_image_input,
    outputs=_text_outputs,
    title="🌟 AI Fairytale Generator from Image",
    description="Upload an image and get a rich fairytale story created from it!",
    theme="soft",
)

# 5. Launch the app
interface.launch(
    share=True,
    debug=True,
    inbrowser=True,
)


# from transformers import BlipProcessor, BlipForConditionalGeneration
# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# from PIL import Image
# import gradio as gr
# import torch

# # 1. Image-description (captioning) model (BLIP)
# caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# # 2. English fairytale generation model (FLAN-T5)
# story_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
# story_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# # 3. Fairytale generation function
# def generate_fairytale(image):
#     # 1. Image -> generate a text description
#     inputs = caption_processor(images=image, return_tensors="pt").to(caption_model.device)
#     outputs = caption_model.generate(**inputs)
#     caption = caption_processor.decode(outputs[0], skip_special_tokens=True).strip()

#     # 2. Build the prompt (with an added no-repetition requirement)
#     prompt = f"""Write a magical and imaginative children's story based on the following image description. 
# Description: "{caption}"
# The story should be at least 3 paragraphs and must not repeat any sentences. 
# Story:"""

#     # 3. Generate the text
#     input_ids = story_tokenizer(prompt, return_tensors="pt").input_ids.to(story_model.device)
#     output_ids = story_model.generate(
#         input_ids,
#         max_length=600,             # longer, richer text
#         num_beams=4,                # beam search
#         no_repeat_ngram_size=3,     # prevent repetition
#         repetition_penalty=1.3,     # repetition penalty
#         early_stopping=True,
#         do_sample=False             # deterministic output
#     )
#     story = story_tokenizer.decode(output_ids[0], skip_special_tokens=True)

#     return caption, story

# # 4. Gradio UI interface
# interface = gr.Interface(
#     fn=generate_fairytale,
#     inputs=gr.Image(type="pil", label="📷 Upload an Image"),
#     outputs=[
#         gr.Textbox(label="📌 Image Description"),
#         gr.Textbox(label="📖 Generated Fairytale")
#     ],
#     title="🌟 AI Fairytale Generator from Image",
#     description="Upload an image and get a rich fairytale story created from it!",
#     theme="default"
# )

# # 5. Launch
# interface.launch(share=True)