# handler.py
import io, os, base64, requests, torch
from PIL import Image
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path=""):
        # Optional: 4-bit quantization via environment variable
        load_4bit = os.getenv("LOAD_IN_4BIT", "0") == "1"
        qcfg = (
            BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
            if load_4bit
            else None
        )
        # ⚠️ Use the local path provided by the endpoint
        # ⚠️ Force the "eager" attention implementation (no SDPA, no FlashAttention)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            quantization_config=qcfg,
            device_map="auto",
            multimodal_max_length=int(os.getenv("MULTIMODAL_MAX_LENGTH", "8192")),
            attn_implementation="eager",
            llm_attn_implementation="eager",  # this field is read by some VLM architectures
        )
        self.txt_tok = self.model.get_text_tokenizer()
        self.vis_tok = self.model.get_visual_tokenizer()

    def _load_image(self, spec):
        # Accept either a URL or a base64-encoded payload for each image
        if "url" in spec:
            r = requests.get(spec["url"], timeout=10)
            r.raise_for_status()
            return Image.open(io.BytesIO(r.content)).convert("RGB")
        if "base64" in spec:
            return Image.open(io.BytesIO(base64.b64decode(spec["base64"]))).convert("RGB")
        raise ValueError("image must have 'url' or 'base64'")

    def __call__(self, data):
        prompt = data.get("prompt", "")
        imgs_spec = data.get("images", [])
        max_new = int(data.get("max_new_tokens", 512))

        images = [self._load_image(s) for s in imgs_spec]
        if images:
            # One "<image>" placeholder per image, consumed by preprocess_inputs
            prefix = "\n".join(["<image>"] * len(images))
            query = f"{prefix}\n{prompt}"
            # Fewer visual partitions per image when several images are sent
            max_part = 4 if len(images) > 1 else 9
        else:
            query, max_part = prompt, None

        prompt, input_ids, pix = self.model.preprocess_inputs(query, images, max_partition=max_part)
        attn = (input_ids != self.txt_tok.pad_token_id).unsqueeze(0).to(self.model.device)
        input_ids = input_ids.unsqueeze(0).to(self.model.device)
        pix = (
            [pix.to(dtype=self.vis_tok.dtype, device=self.vis_tok.device)]
            if pix is not None
            else None
        )

        with torch.inference_mode():
            out_ids = self.model.generate(
                input_ids,
                pixel_values=pix,
                attention_mask=attn,
                max_new_tokens=max_new,
                do_sample=False,
                use_cache=True,
                eos_token_id=self.model.generation_config.eos_token_id,
                pad_token_id=self.txt_tok.pad_token_id,
            )[0]
        text = self.txt_tok.decode(out_ids, skip_special_tokens=True)
        return {"output": text}
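

# --- Local smoke test (sketch) ---------------------------------------------
# Not part of the endpoint runtime: Hugging Face Inference Endpoints construct
# EndpointHandler(path) and call handler(data) themselves. The block below is
# only a hedged example of the payload shape this handler expects ("prompt",
# "images", "max_new_tokens"); MODEL_PATH and the image URL are placeholders,
# not values from the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(path=os.getenv("MODEL_PATH", "."))
    payload = {
        "prompt": "Describe this image.",
        "images": [{"url": "https://example.com/sample.jpg"}],  # or {"base64": "..."}
        "max_new_tokens": 128,
    }
    print(handler(payload)["output"])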