import requests
import torch
from PIL import Image
from transformers import AutoProcessor
from datasets import load_dataset
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers.tracing import TraceableLlavaForConditionalGeneration
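# NOTE: the Traceable* wrapper makes the multimodal graph traceable so the
# sequential GPTQ pipeline can step through it layer by layer; the stock
# LlavaForConditionalGeneration class is not guaranteed to trace cleanly.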
# Load model.
model_id = "llama-joycaption-beta-one-hf-llava"
model = TraceableLlavaForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="bfloat16"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = "test"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
PROMPT = "Write a long descriptive caption for this image in a formal tone."
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
def preprocess_function(example):
    # Build the conversation
    convo = [
        {
            "role": "system",
            "content": "You are a helpful image captioner.",
        },
        {
            "role": "user",
            "content": PROMPT,
        },
        {"role": "assistant", "content": " ".join(example["caption"])},
    ]

    # Format the conversation
    # WARNING: HF's handling of chats on Llava models is very fragile. This
    # specific combination of processor.apply_chat_template() and processor()
    # works, but if you use another combination, always inspect the final
    # input_ids to ensure they are correct. It is easy to end up with multiple
    # <bos> tokens if you are not careful, which can make the model perform
    # poorly.
    convo_string = processor.apply_chat_template(
        convo, tokenize=False, add_generation_prompt=True
    )
    assert isinstance(convo_string, str)

    # Process the inputs
    inputs = processor(
        text=[convo_string], images=[example["image"]], return_tensors="pt"
    ).to("cuda")
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    return inputs
ds = ds.map(preprocess_function)
# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    # ds.map() serializes tensors to nested lists, so rebuild them here.
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
    }
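# Optional smoke test (illustrative addition): collate one mapped sample and
# print the tensor shapes before committing to a full calibration run.
probe = data_collator([ds[0]])
print({key: tuple(value.shape) for key, value in probe.items()})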
# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W8A8",
        sequential_targets=["LlamaDecoderLayer"],
        ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    ),
]
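# Illustrative alternative recipe (not used below): weight-only INT4 with the
# same ignore list, for deployments without INT8 activation support.
# `recipe_w4a16` is a hypothetical variant; pass it as oneshot(recipe=...) to
# try it instead.
recipe_w4a16 = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        sequential_targets=["LlamaDecoderLayer"],
        ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
    ),
]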
SAVE_DIR = model_id + "-W8A8"
# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=ds,
    splits={"calibration": f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]"},
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
    output_dir=SAVE_DIR,
)
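# Also save the processor so the quantized checkpoint loads standalone
# (assumption: oneshot's output_dir writes the model files but not the
# processor/tokenizer configs).
processor.save_pretrained(SAVE_DIR)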
# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)
inputs = processor(images=[raw_image], text=prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0], skip_special_tokens=True))
print("==========================================")