# exampleone / app.py
# kouki321 — "Update app.py" (commit eefa3ef, verified)
import os
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException, Body
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache , StaticCache
from pydantic import BaseModel
from typing import Optional
import tempfile
from time import time
from fastapi.responses import RedirectResponse
# Allow-list the extra types that torch.load may reconstruct in weights_only
# mode: the transformers DynamicCache class and the builtin set.
torch.serialization.add_safe_globals([DynamicCache, set])
def generate(model, input_ids, past_key_values, max_new_tokens=50):
    """Greedily decode up to ``max_new_tokens`` tokens continuing from a KV cache.

    Args:
        model: causal LM; its input device is taken from
            ``model.model.embed_tokens.weight.device``.
        input_ids: token ids fed on the first step. Only batch size 1 is
            supported (``token.item()`` is used for the EOS check).
        past_key_values: cache passed to the first forward call; each step
            replaces it with ``out.past_key_values``.
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        Tensor of only the newly generated ids (the ``input_ids`` prefix is
        dropped). If an EOS id is produced, it is included and decoding stops.
    """
    device = model.model.embed_tokens.weight.device
    origin_len = input_ids.shape[-1]
    input_ids = input_ids.to(device)
    output_ids = input_ids.clone()
    next_token = input_ids

    # config.eos_token_id may be None, a single int, or a list of ids;
    # normalise it once so membership works for every form.
    eos = model.config.eos_token_id
    if eos is None:
        eos_ids = set()
    elif isinstance(eos, (list, tuple, set)):
        eos_ids = set(eos)
    else:
        eos_ids = {eos}

    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True,
            )
            logits = out.logits[:, -1, :]
            # Move the token to the input device *before* concatenating: with
            # device_map="auto" the logits may live on a different device than
            # the embeddings (and therefore than output_ids).
            token = torch.argmax(logits, dim=-1, keepdim=True).to(device)
            output_ids = torch.cat([output_ids, token], dim=-1)
            past_key_values = out.past_key_values
            next_token = token
            if token.item() in eos_ids:
                break
    return output_ids[:, origin_len:]
def get_kv_cache(model, tokenizer, prompt):
    """Run ``prompt`` through ``model`` once and return the populated cache.

    Returns:
        ``(cache, prompt_token_count)``: the ``DynamicCache`` that was passed
        to the forward call, and the number of tokens ``prompt`` encoded to.
    """
    target_device = model.model.embed_tokens.weight.device
    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_ids = encoded.input_ids.to(target_device)
    kv_cache = DynamicCache()
    with torch.no_grad():
        # The forward output is discarded; only the cache object handed in
        # is kept and returned.
        model(
            input_ids=prompt_ids,
            past_key_values=kv_cache,
            use_cache=True,
        )
    return kv_cache, prompt_ids.shape[-1]
def clean_up(cache, origin_len):
    """Return an independent copy of ``cache`` truncated to ``origin_len``.

    Every key/value tensor is sliced along dim 2 to its first ``origin_len``
    entries and then cloned into a new ``DynamicCache``. ``cache`` itself is
    only read, never modified. (Dim 2 is presumably the sequence axis of a
    (batch, heads, seq, head_dim) tensor; ``load_cache`` reads ``shape[-2]``
    of the same tensors as the length — assumption, confirm.)
    """
    new_cache = DynamicCache()
    for keys, values in zip(cache.key_cache, cache.value_cache):
        # Slice first, then clone: only the kept prefix is copied, and the
        # result owns its own storage instead of being a view that pins a
        # full-length clone in memory.
        new_cache.key_cache.append(keys[:, :, :origin_len, :].clone())
        new_cache.value_cache.append(values[:, :, :origin_len, :].clone())
    return new_cache
# Request offline mode from transformers / huggingface_hub.
# NOTE(review): these are assigned after `transformers` has already been
# imported at the top of the file; if the libraries read them at import time
# these assignments have no effect. Consider moving them above the imports —
# confirm against the installed versions.
os.environ.update({"TRANSFORMERS_OFFLINE": "1", "HF_HUB_OFFLINE": "1"})
def load_model_and_tokenizer():
    """Load the tokenizer and causal LM from ``$MODEL_PATH`` (default ``./model``).

    On a CUDA machine the model is loaded in float16 with ``device_map="auto"``;
    otherwise in float32 with ``low_cpu_mem_usage=True``.

    Returns:
        ``(model, tokenizer)``.
    """
    model_path = os.environ.get("MODEL_PATH", "./model")  # overridable via Docker env
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if torch.cuda.is_available():
        load_kwargs = {"torch_dtype": torch.float16, "device_map": "auto"}
    else:
        load_kwargs = {"torch_dtype": torch.float32, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return model, tokenizer
# Application object; the route decorators below register on it.
app = FastAPI(title="DeepSeek QA with KV Cache API")
# In-memory registry: cache_id -> {"cache", "origin_len", "doc_preview"}.
cache_store = dict()
# Loaded once at import time and shared by every endpoint.
model, tokenizer = load_model_and_tokenizer()
class QueryRequest(BaseModel):
    # Request body for /generate_answer_from_cache/{cache_id}.
    # The question text; it is inserted after "Question: " in the prompt.
    query: str
    # Passed to generate() as max_new_tokens. Optional[int] means a client may
    # also send an explicit null.
    max_new_tokens: Optional[int] = 150
def clean_response(response_text):
    """Extract a readable answer from raw decoded model output.

    First scans for ``<|assistant|>`` segments terminated by ``</|assistant|>``,
    ``<|user|>`` or ``<|system|>``. It returns the first segment whose stripped
    text is longer than five characters and does not start with ``<|``.

    Failing that, it removes ``<|...|>`` and ``</|...|>`` markers, drops blank
    and repeated lines (keeping first occurrences in order), strips a trailing
    marker, and returns the result.
    """
    import re

    segment_re = re.compile(
        r'<\|assistant\|>\s*(.*?)(?:<\/\|assistant\|>|<\|user\|>|<\|system\|>)',
        re.DOTALL,
    )
    for segment in segment_re.findall(response_text):
        candidate = segment.strip()
        if len(candidate) > 5 and not candidate.startswith("<|"):
            return candidate

    # No usable tagged segment: strip the markup and de-duplicate lines.
    text = re.sub(r'<\|.*?\|>', '', response_text)
    text = re.sub(r'<\/\|.*?\|>', '', text)
    stripped_lines = (raw.strip() for raw in text.strip().split('\n'))
    # dict.fromkeys keeps only the first occurrence of each line, in order.
    kept = list(dict.fromkeys(line for line in stripped_lines if line))
    joined = '\n'.join(kept)
    joined = re.sub(r'<\/?\|.*?\|>\s*$', '', joined)
    return joined.strip()
@app.post("/upload-document_to_create_KV_cache")
async def upload_document(file: UploadFile = File(...)):
    # Build and store a KV cache for an uploaded UTF-8 text document.
    #
    # The document is wrapped in a fixed system/user prompt that ends with
    # "Question:", run through the model once via get_kv_cache, and stored in
    # cache_store under a fresh id. The client passes that id to
    # /generate_answer_from_cache/{cache_id}.
    #
    # Raises HTTPException(500) if reading/decoding the file or the model
    # call fails.
    import uuid

    t1 = time()
    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())
        # Re-read in text mode: decodes UTF-8 and applies universal newlines.
        with open(temp_file_path, "r", encoding="utf-8") as f:
            doc_text = f.read()
        system_prompt = (
            "<|system|>\n"
            "Answer concisely and precisely, You are an assistant who provides concise factual answers.\n"
            "<|user|>\n"
            "Context:\n"
            f"{doc_text}\n"
            "Question:"
        )
        cache, origin_len = get_kv_cache(model, tokenizer, system_prompt)
        # A per-second timestamp alone collides when two uploads land in the
        # same second (the second silently overwrote the first); add a random
        # suffix to keep ids unique.
        cache_id = f"cache_{int(time())}_{uuid.uuid4().hex[:8]}"
        doc_preview = doc_text[:500] + "..." if len(doc_text) > 500 else doc_text
        cache_store[cache_id] = {
            "cache": cache,
            "origin_len": origin_len,
            "doc_preview": doc_preview,
        }
        t2 = time()
        return {
            "cache_id": cache_id,
            "message": "Document uploaded and cache created successfully",
            "doc_preview": doc_preview,
            "time_taken": f"{t2 - t1:.4f} seconds",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing document: {str(e)}") from e
    finally:
        # Always remove the temp file, including when the upload read fails.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@app.post("/generate_answer_from_cache/{cache_id}")
async def generate_answer(cache_id: str, request: QueryRequest):
    # Answer request.query by continuing generation from a stored document cache.
    #
    # Raises HTTPException(404) for an unknown cache_id and
    # HTTPException(500) if generation fails.
    t1 = time()
    if cache_id not in cache_store:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    # QueryRequest allows an explicit null; fall back to the schema default
    # (150) instead of crashing in range(None) inside generate().
    max_new_tokens = 150 if request.max_new_tokens is None else request.max_new_tokens
    try:
        entry = cache_store[cache_id]
        # Decode into a cloned copy (clean_up) so the stored document cache
        # keeps its original length across requests.
        current_cache = clean_up(entry["cache"], entry["origin_len"])
        full_prompt = f"<|user|>\nQuestion: {request.query}\n<|assistant|>"
        # This text continues a sequence whose start is already in the cache,
        # so the tokenizer must not prepend special tokens (e.g. BOS) again.
        input_ids = tokenizer(full_prompt, return_tensors="pt", add_special_tokens=False).input_ids
        output_ids = generate(model, input_ids, current_cache, max_new_tokens=max_new_tokens)
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        rep = clean_response(response)
        t2 = time()
        return {
            "query": request.query,
            "answer": rep,
            "time_taken": f"{t2 - t1:.4f} seconds",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}") from e
@app.post("/save_cache/{cache_id}")
async def save_cache(cache_id: str):
    # Write a cloned copy of a stored cache to "<cache_id>_cache.pth" in the
    # current working directory using torch.save.
    entry = cache_store.get(cache_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="Document not found. Please upload it first.")
    try:
        snapshot = clean_up(entry["cache"], entry["origin_len"])
        cache_path = f"{cache_id}_cache.pth"
        torch.save(snapshot, cache_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving cache: {str(e)}")
    return {
        "message": f"Cache saved successfully as {cache_path}",
        "cache_path": cache_path,
    }
@app.post("/load_cache")
async def load_cache(file: UploadFile = File(...)):
    # Register a cache file (as written by /save_cache) under a fresh cache_id.
    #
    # Raises HTTPException(500) if the upload cannot be read/unpickled or does
    # not contain a non-empty KV cache.
    import uuid

    temp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())
        # SECURITY: this file comes from the client. weights_only=True limits
        # unpickling to tensors plus the types allow-listed with
        # torch.serialization.add_safe_globals. It is the default only on newer
        # torch releases, so state it explicitly. Never load uploads with
        # weights_only=False: that permits arbitrary code execution.
        loaded_cache = torch.load(temp_file_path, weights_only=True)
        key_cache = getattr(loaded_cache, "key_cache", None)
        if not key_cache:
            raise ValueError("file does not contain a non-empty KV cache")
        # Random suffix: a per-second timestamp alone collides for two loads
        # in the same second.
        cache_id = f"loaded_cache_{int(time())}_{uuid.uuid4().hex[:8]}"
        cache_store[cache_id] = {
            "cache": loaded_cache,
            "origin_len": key_cache[0].shape[-2],
            "doc_preview": "Loaded from cache file",
        }
        return {
            "cache_id": cache_id,
            "message": "Cache loaded successfully",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading cache: {str(e)}") from e
    finally:
        # Always remove the temp file, including when the upload read fails.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
@app.get("/list_of_caches")
async def list_documents():
    # Summarise every stored cache: its preview text and prefix length.
    documents = {
        cid: {
            "doc_preview": entry["doc_preview"],
            "origin_len": entry["origin_len"],
        }
        for cid, entry in cache_store.items()
    }
    return {"documents": documents}
@app.get("/", include_in_schema=False)
async def root():
    # Redirect the bare root URL to /docs (route hidden from the schema).
    redirect = RedirectResponse(url="/docs")
    return redirect
if __name__ == "__main__":
    # Script entry point: serve `app` on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")