from transformers import AutoTokenizer, BitsAndBytesConfig, LlamaForCausalLM
import torch
class EndpointHandler:
    def __init__(self, model_path="."):
        # Load the model in 4-bit NF4 with double quantization to reduce memory usage,
        # computing in bfloat16.
        self.model = LlamaForCausalLM.from_pretrained(
            model_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4'
            )
        )
        # Llama 3 chat models end each turn with <|eot_id|>; use it as both EOS and pad token.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, eos_token="<|eot_id|>")
        self.tokenizer.pad_token = self.tokenizer.eos_token
    def __call__(self, request_data):
        prompt = request_data["prompt"]
        chat = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt}
        ]
        # Render the chat template as text so the echoed prompt can be stripped from the output.
        prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        # Move the inputs onto the same device as the quantized model.
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
        # max_length caps prompt and completion combined at 400 tokens.
        output = self.model.generate(input_ids, max_length=400)
        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=False)
        # Keep only the completion: drop the echoed prompt and leftover special tokens.
        generated_text = (
            generated_text.replace(prompt, '')
            .replace('<|begin_of_text|>', '')
            .replace('<|eot_id|>', '')
            .strip()
        )
        return {"response": generated_text}