import gradio as gr
from transformers import (
    Blip2Processor,
    Blip2ForConditionalGeneration,
    AutoTokenizer,
    AutoModelForCausalLM,
)
from PIL import Image
import torch

# Device setup (Mistral is forced to run on CPU)
device = "cpu"

# 1️⃣ Load the image captioning model (BLIP-2)
processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xl")
blip_model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-flan-t5-xl").to(device)

# 2️⃣ Load the fairytale generation model (Mistral-7B, CPU)
llm_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
llm_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", device_map="cpu")
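
# Note: loading Mistral-7B in full float32 precision needs roughly 28 GB of RAM,
# which can exceed the memory of smaller CPU hardware. A lower-memory loading
# sketch (an optional alternative, not used above; assumes accelerate is installed):
#
#     llm_model = AutoModelForCausalLM.from_pretrained(
#         "mistralai/Mistral-7B-Instruct-v0.2",
#         device_map="cpu",
#         torch_dtype=torch.bfloat16,   # roughly halves memory vs. float32
#         low_cpu_mem_usage=True,       # avoid materializing a second full copy of the weights
#     )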

# 3️⃣ Image → caption
def extract_caption(image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    outputs = blip_model.generate(**inputs, max_new_tokens=50)
    caption = processor.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return caption
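
# Quick standalone check of the captioner (file name is hypothetical, for illustration only):
#     print(extract_caption(Image.open("example.jpg")))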

# 4️⃣ Build the prompt
def build_prompt(caption):
    return (
        f"Based on the image description: \"{caption}\", write a children's fairytale.\n"
        "The story must:\n"
        "- Start with 'Once upon a time'\n"
        "- Be at least 10 full sentences long\n"
        "- Include named characters, a clear setting, emotions, a challenge, and a resolution\n"
        "- Avoid unrelated royalty or babies unless relevant\n"
        "Here is the story:\nOnce upon a time"
    )

# 5️⃣ Generate the full fairytale
def generate_fairytale(image):
    caption = extract_caption(image)
    prompt = build_prompt(caption)
    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
    output = llm_model.generate(
        **inputs,
        max_new_tokens=500,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        pad_token_id=llm_tokenizer.eos_token_id
    )
    result = llm_tokenizer.decode(output[0], skip_special_tokens=True)
    # Keep only the text from "Once upon a time" onward
    if "Once upon a time" in result:
        return caption, "Once upon a time" + result.split("Once upon a time", 1)[-1].strip()
    else:
        return caption, f"⚠️ Story generation failed.\n\n[Prompt]\n{prompt}\n\n[Output]\n{result}"
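
# Optional: Mistral-7B-Instruct is fine-tuned on [INST] ... [/INST] chat formatting, so
# wrapping the prompt with the tokenizer's chat template may improve how closely the model
# follows the instructions. A sketch (assumes a transformers version with apply_chat_template):
#     chat = [{"role": "user", "content": prompt}]
#     prompt = llm_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)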

# 6️⃣ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🧚 AI Fairytale Generator (Mistral CPU ver.)")
    gr.Markdown("Upload an image and receive a magical children's fairytale based on it ✨")

    image_input = gr.Image(type="pil", label="🖼️ Upload an Image")
    generate_button = gr.Button("✨ Generate Fairytale")
    caption_output = gr.Textbox(label="📌 Image Description", lines=2)
    story_output = gr.Textbox(label="📖 Generated Fairytale", lines=20)

    generate_button.click(fn=generate_fairytale, inputs=[image_input], outputs=[caption_output, story_output])

if __name__ == "__main__":
    demo.launch(share=True)
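
# A matching requirements.txt for this app might look like the following (the package list is
# an assumption derived from the imports above; versions are left unpinned on purpose):
#     gradio
#     transformers
#     torch
#     accelerate       # required when passing device_map to from_pretrained
#     Pillow
#     sentencepiece    # tokenizer backend for the FLAN-T5 / Mistral checkpoints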