gamin committed
Commit 2c99f9b · 1 Parent(s): a2e0332

Use Mistral-7B-Instruct-v0.2 on CPU (6000s)

Files changed (1)
  1. Mistral-7B-Instruct-v0.2.py +160 -0
Mistral-7B-Instruct-v0.2.py ADDED
@@ -0,0 +1,160 @@
+ # import gradio as gr
+ # from transformers import (
+ #     Blip2Processor,
+ #     Blip2ForConditionalGeneration,
+ #     AutoTokenizer,
+ #     AutoModelForCausalLM,
+ # )
+ # from PIL import Image
+ # import torch
+
+ # # Set device
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # # Load image captioning model (BLIP-2)
+ # processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ # blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
+
+ # # Load text generation model (LLM)
+ # llm_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+ # llm_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2").to(device)
+
+ # # Step 1: Generate image caption
+ # def extract_caption(image):
+ #     inputs = processor(images=image, return_tensors="pt").to(device)
+ #     outputs = blip_model.generate(**inputs, max_new_tokens=50)
+ #     caption = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
+ #     return caption
+
+ # # Step 2: Build fairytale prompt
+ # def build_prompt(caption):
+ #     return (
+ #         f"Based on the image description: \"{caption}\", write a children's fairytale.\n"
+ #         "The story must:\n"
+ #         "- Start with 'Once upon a time'\n"
+ #         "- Be at least 10 full sentences long\n"
+ #         "- Include named characters, a clear setting, emotions, a challenge, and a resolution\n"
+ #         "- Avoid mentions of babies or unrelated royalty unless relevant\n"
+ #         "Here is the story:\nOnce upon a time"
+ #     )
+
+ # # Step 3: Generate story
+ # def generate_fairytale(image):
+ #     caption = extract_caption(image)
+ #     prompt = build_prompt(caption)
+ #     inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
+
+ #     output = llm_model.generate(
+ #         **inputs,
+ #         max_new_tokens=500,
+ #         do_sample=True,
+ #         temperature=0.9,
+ #         top_p=0.95,
+ #         pad_token_id=llm_tokenizer.eos_token_id
+ #     )
+
+ #     result = llm_tokenizer.decode(output[0], skip_special_tokens=True)
+
+ #     # Trim to only the story
+ #     if "Once upon a time" in result:
+ #         return "Once upon a time" + result.split("Once upon a time", 1)[-1].strip()
+ #     else:
+ #         return f"⚠️ Failed to generate story.\n\n[Prompt]\n{prompt}\n\n[Output]\n{result}"
+
+ # # Gradio interface
+ # with gr.Blocks() as demo:
+ #     gr.Markdown("## 📖 AI Fairytale Generator\nUpload an image and get a magical story!")
+
+ #     with gr.Row():
+ #         image_input = gr.Image(type="pil", label="Upload an image")
+
+ #     with gr.Row():
+ #         generate_button = gr.Button("✨ Generate Fairytale")
+
+ #     with gr.Row():
+ #         output_text = gr.Textbox(label="Generated Story", lines=20)
+
+ #     generate_button.click(fn=generate_fairytale, inputs=[image_input], outputs=[output_text])
+
+ # if __name__ == "__main__":
+ #     demo.launch(share=True)
+
+ import gradio as gr
+ from transformers import (
+     Blip2Processor,
+     Blip2ForConditionalGeneration,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+ )
+ from PIL import Image
+ import torch
+
+ # Set device (Mistral is forced to run on CPU)
+ device = "cpu"
+
+ # 1️⃣ Load the image captioning model (BLIP-2)
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
+ blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)
+
+ # 2️⃣ Load the fairytale generation model (Mistral-7B, CPU)
+ llm_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
+ llm_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", device_map="cpu")
+
+ # 3️⃣ Image → caption
+ def extract_caption(image):
+     inputs = processor(images=image, return_tensors="pt").to(device)
+     outputs = blip_model.generate(**inputs, max_new_tokens=50)
+     caption = processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+     return caption
+
+ # 4️⃣ Build the prompt
+ def build_prompt(caption):
+     return (
+         f"Based on the image description: \"{caption}\", write a children's fairytale.\n"
+         "The story must:\n"
+         "- Start with 'Once upon a time'\n"
+         "- Be at least 10 full sentences long\n"
+         "- Include named characters, a clear setting, emotions, a challenge, and a resolution\n"
+         "- Avoid unrelated royalty or babies unless relevant\n"
+         "Here is the story:\nOnce upon a time"
+     )
+
+ # 5️⃣ Generate the full fairytale
+ def generate_fairytale(image):
+     caption = extract_caption(image)
+     prompt = build_prompt(caption)
+
+     inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
+
+     output = llm_model.generate(
+         **inputs,
+         max_new_tokens=500,
+         do_sample=True,
+         temperature=0.9,
+         top_p=0.95,
+         pad_token_id=llm_tokenizer.eos_token_id
+     )
+
+     result = llm_tokenizer.decode(output[0], skip_special_tokens=True)
+
+     # Keep only the text after "Once upon a time"
+     if "Once upon a time" in result:
+         return caption, "Once upon a time" + result.split("Once upon a time", 1)[-1].strip()
+     else:
+         return caption, f"⚠️ Story generation failed.\n\n[Prompt]\n{prompt}\n\n[Output]\n{result}"
+
+ # 6️⃣ Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("## 🧚 AI Fairytale Generator (Mistral CPU ver.)")
+     gr.Markdown("Upload an image and receive a magical children's fairytale based on it ✨")
+
+     image_input = gr.Image(type="pil", label="🖼️ Upload an Image")
+     generate_button = gr.Button("✨ Generate Fairytale")
+
+     caption_output = gr.Textbox(label="📌 Image Description", lines=2)
+     story_output = gr.Textbox(label="📖 Generated Fairytale", lines=20)
+
+     generate_button.click(fn=generate_fairytale, inputs=[image_input], outputs=[caption_output, story_output])
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
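
Note (not part of this commit): the commit message puts a CPU run at roughly 6000 s. A minimal sketch of one possible mitigation, assuming a torch build with bfloat16 CPU support, is to load the same Mistral checkpoint in lower precision; any speed or memory gain is an assumption and would need to be measured on the target machine.

    # Sketch only: lower-precision CPU load for Mistral-7B-Instruct-v0.2
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    llm_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    llm_model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        torch_dtype=torch.bfloat16,   # float32 -> bfloat16 roughly halves weight memory
        low_cpu_mem_usage=True,       # avoid materializing a full fp32 copy while loading
        device_map="cpu",
    )

The rest of the script (tokenization, generate() arguments, Gradio wiring) would stay as committed; only the from_pretrained call changes.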