Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig | |
| import torch | |
# Base Idefics2 checkpoint, and the repo holding the DPO-tuned adapter + processor.
model_id = "HuggingFaceM4/idefics2-8b"
peft_model_id = "HuggingFaceH4/idefics2-8b-dpo-rlaif-v-v0.3"

# 4-bit NF4 weights with nested (double) quantization; compute runs in fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# The processor is loaded from the adapter repo, not from the base repo.
processor = AutoProcessor.from_pretrained(peft_model_id)

model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)
# Attach the adapter weights from the DPO repo on top of the quantized base model.
model.load_adapter(peft_model_id)
def build_messages(multimodal_input):
    """Build the single-turn user message expected by the processor's chat template.

    One ``{"type": "image"}`` placeholder is emitted per uploaded file, followed
    by the user's text, so the number of image slots in the prompt matches the
    number of images handed to the processor.

    Args:
        multimodal_input: dict produced by ``gr.MultimodalTextbox`` with
            ``"text"`` and ``"files"`` keys. Missing or ``None`` values are
            treated as empty.

    Returns:
        A one-element list containing the user message dict.
    """
    files = multimodal_input.get("files") or []
    content = [{"type": "image"} for _ in files]
    content.append({"type": "text", "text": multimodal_input.get("text") or ""})
    return [{"role": "user", "content": content}]


def respond(multimodal_input):
    """Generate a reply to a text query with zero or more attached images.

    Args:
        multimodal_input: dict from ``gr.MultimodalTextbox`` with ``"text"``
            (the question) and ``"files"`` (uploaded image paths).

    Returns:
        The decoded model continuation only (the prompt is stripped).
    """
    images = multimodal_input.get("files") or []
    messages = build_messages(multimodal_input)
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    # For a text-only query, pass images=None. An empty per-sample list
    # ([[]]) would make the image processor fail.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = len(inputs["input_ids"][0])
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=500)
    # generate() returns prompt + continuation; keep only the new tokens.
    new_tokens = generated_ids[:, prompt_len:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
# One multimodal textbox in, plain text out. The examples point at local
# image files (./bee.jpg, ./howl.jpg) relative to the working directory.
demo = gr.Interface(
    fn=respond,
    inputs=[gr.MultimodalTextbox(file_types=["image"], show_label=False)],
    outputs="text",
    title="IDEFICS2-8B DPO",
    description="Try IDEFICS2-8B fine-tuned using direct preference optimization (DPO) in this demo. Learn more about vision language model DPO integration of TRL [here](https://huggingface.co/blog/dpo_vlm).",
    examples=[
        {"text": "What is the type of flower in the image and what insect is on it?", "files": ["./bee.jpg"]},
        {"text": "Describe the image", "files": ["./howl.jpg"]},
    ],
)
demo.launch()