from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel
from typing import List
from clearml import Model
import torch
from configs import add_args
from models import build_or_load_gen_model
import argparse
from argparse import Namespace
import os
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig

MAX_SOURCE_LENGTH = 512

def pad_assert(tokenizer, source_ids):
    # Truncate, wrap with BOS/EOS, and pad to the fixed model input length
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids

# Encode code content and comment into model input
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Pad/truncate to fixed length (same logic as pad_assert)
    return pad_assert(tokenizer, source_ids)
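
# Illustrative usage only (hypothetical code/comment strings); the call always
# returns exactly MAX_SOURCE_LENGTH token ids because of the padding/truncation above:
#   ids = encode_diff(tokenizer, "def add(a, b):\n    return a + b", "Please add type hints")
#   assert len(ids) == MAX_SOURCE_LENGTH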

# Load base model architecture and tokenizer from HuggingFace
BASE_MODEL_NAME = "microsoft/codereviewer"

args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
    # Add other necessary default arguments if build_or_load_gen_model requires them
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")

# Download the fine-tuned weights from ClearML
CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
model_obj = Model(model_id=CLEARML_MODEL_ID)
finetuned_weights_path = model_obj.get_local_copy()
adapter_dir = os.path.dirname(finetuned_weights_path)
print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}")

# Create LoRA configuration matching the fine-tuned checkpoint
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
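# Note: the target_modules names follow T5-style module naming (CodeReviewer is a
# T5-based encoder-decoder): q/k/v/o are the attention projections, wi/wo the
# feed-forward projections. They must match the names used when the adapter was trained.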

# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)

# Load adapter-only weights and merge into base
adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
peft_model.load_state_dict(adapter_state, strict=False)
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")

model.to("cpu")
model.eval()
print("Model ready for inference.")

app = FastAPI()

last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}

class FileContent(BaseModel):
    filename: str
    content: str

class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]

class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]
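
# Both request bodies share the same JSON shape, e.g. (illustrative values only):
#   {"comment": "Please add error handling",
#    "files": [{"filename": "main.py", "content": "print('hi')"}]}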

@app.get("/")  # NOTE: route path assumed; the original decorator was lost in extraction
def root():
    return {"message": "FastAPI PR comment service is running"}

@app.post("/pr-comment")  # NOTE: route path assumed; the original decorator was lost in extraction
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Return the received payload as JSON and also redirect to /show
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})

@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
    for file in last_payload["files"]:
        html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
    return html

@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    # tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    # print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1,
    )
    pred = preds[0].cpu().numpy()
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result

@app.get("/show-infer", response_class=HTMLResponse)  # NOTE: route path assumed; the original decorator was lost in extraction
def show_infer_result():
    html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
    return html

if __name__ == "__main__":
    # Place any CLI/training logic here if needed
    # This block is NOT executed when running with uvicorn
    pass
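
# Minimal client-side sketch (hypothetical host/port; module name "app" assumed).
# Start the service first, e.g.:  uvicorn app:app --host 0.0.0.0 --port 8000
#
#   import requests
#   payload = {
#       "comment": "Please handle the empty-list case",
#       "files": [{"filename": "utils.py", "content": "def first(xs):\n    return xs[0]"}],
#   }
#   print(requests.post("http://localhost:8000/infer", json=payload).json())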