# File size: 5,845 Bytes
# 2c99f9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# NOTE: an earlier CUDA-capable draft of this script (identical pipeline,
# single-output UI) was removed from here; recover it from version control
# history if needed.

import gradio as gr
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from PIL import Image
import torch

# Device selection: everything runs on CPU (Mistral is deliberately forced off GPU).
device = "cpu"

# 1) Image-captioning model (BLIP-2, FLAN-T5-XL variant); moved to `device`.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)

# 2) Story-generation LLM (Mistral-7B Instruct), pinned to CPU via device_map.
llm_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
llm_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", device_map="cpu")

# 3) Image -> caption
def extract_caption(image):
    """Return a short caption describing *image*, produced by BLIP-2."""
    encoded = processor(images=image, return_tensors="pt").to(device)
    token_ids = blip_model.generate(**encoded, max_new_tokens=50)
    decoded = processor.tokenizer.decode(token_ids[0], skip_special_tokens=True)
    return decoded.strip()

# 4) Prompt construction
def build_prompt(caption):
    """Compose the instruction prompt sent to the story LLM.

    The prompt embeds *caption*, lists the story requirements, and ends with
    the cue line "Once upon a time" so the model continues from it.
    """
    requirement_lines = [
        "The story must:",
        "- Start with 'Once upon a time'",
        "- Be at least 10 full sentences long",
        "- Include named characters, a clear setting, emotions, a challenge, and a resolution",
        "- Avoid unrelated royalty or babies unless relevant",
        "Here is the story:",
        "Once upon a time",
    ]
    header = f"Based on the image description: \"{caption}\", write a children's fairytale.\n"
    return header + "\n".join(requirement_lines)

# 5) Full pipeline: image -> caption -> fairytale
def generate_fairytale(image):
    """Generate a children's fairytale from *image*.

    Returns a ``(caption, story)`` tuple: the BLIP-2 caption of the image and
    the Mistral-generated story, or a diagnostic message if generation
    produced no new text.
    """
    caption = extract_caption(image)
    prompt = build_prompt(caption)

    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)

    output = llm_model.generate(
        **inputs,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        pad_token_id=llm_tokenizer.eos_token_id
    )

    # BUGFIX: the decoded output echoes the prompt, and the prompt contains
    # "Once upon a time" twice ("- Start with 'Once upon a time'" and the
    # final cue line).  Splitting the full decoded text on the phrase matched
    # the FIRST occurrence and leaked the remaining prompt instructions into
    # the returned "story".  Decode only the tokens generated *after* the
    # prompt instead.
    prompt_len = inputs["input_ids"].shape[1]
    continuation = llm_tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    if continuation:
        # The prompt ends with the cue "Once upon a time"; re-attach it so the
        # story starts as required.
        return caption, "Once upon a time " + continuation
    else:
        return caption, f"⚠️ Story generation failed.\n\n[Prompt]\n{prompt}\n\n[Output]\n{continuation}"

# 6) Gradio front-end
with gr.Blocks() as demo:
    gr.Markdown("## 🧚 AI Fairytale Generator (Mistral CPU ver.)")
    gr.Markdown("Upload an image and receive a magical children's fairytale based on it ✨")

    uploaded_image = gr.Image(type="pil", label="🖼️ Upload an Image")
    run_button = gr.Button("✨ Generate Fairytale")

    caption_box = gr.Textbox(label="📌 Image Description", lines=2)
    story_box = gr.Textbox(label="📖 Generated Fairytale", lines=20)

    # Wire the button to the pipeline: one image in, caption + story out.
    run_button.click(fn=generate_fairytale, inputs=[uploaded_image], outputs=[caption_box, story_box])

if __name__ == "__main__":
    # share=True exposes a temporary public Gradio URL.
    demo.launch(share=True)