from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import transformers
import charactertokenizer
import os

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MODEL_NAME = 'ai-forever/charllama-2.6B'

app = FastAPI(
    title="CharLLaMA 2.6B API",
    description="API for text generation with the ai-forever/charllama-2.6B model",
    version="1.0.0"
)

# The model and tokenizer are loaded once at startup and shared across requests.
model = None
tokenizer = None

try:
    print(f"Loading tokenizer {MODEL_NAME}...")
    tokenizer = charactertokenizer.CharacterTokenizer.from_pretrained(MODEL_NAME)

    print(f"Loading model {MODEL_NAME} onto device {DEVICE}...")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32
    )
    model.to(DEVICE)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Critical error while loading the model: {e}")


class GenerationInput(BaseModel):
    prompt: str
    max_length: int = 512
    temperature: float = 0.8
    top_p: float = 0.6
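
# Illustrative only, not part of the original file: a request body that the
# /generate endpoint below would accept. Field names mirror GenerationInput;
# the values shown are arbitrary.
#
#   POST /generate
#   {"prompt": "Example prompt text", "max_length": 256, "temperature": 0.8, "top_p": 0.6}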


@app.get("/")
def read_root():
    """Root endpoint for a health check."""
    return {"status": "API is running", "model_loaded": model is not None}


@app.post("/generate")
def generate_text(request: GenerationInput):
    """
    Text generation endpoint.
    Accepts JSON with a 'prompt' field and optional generation parameters.
    """
    if not model or not tokenizer:
        raise HTTPException(
            status_code=503,
            detail="The model was not loaded. Check the Space logs."
        )

    prompt = request.prompt

    # Sampling parameters; top_k=0 disables top-k filtering so only
    # nucleus (top_p) sampling constrains the distribution.
    generation_args = {
        'max_length': request.max_length,
        'num_return_sequences': 1,
        'do_sample': True,
        'no_repeat_ngram_size': 10,
        'temperature': request.temperature,
        'top_p': request.top_p,
        'top_k': 0,
    }

    try:
        input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(DEVICE)
        prompt_len = input_ids.shape[1]

        print("Starting generation...")
        output_ids = model.generate(
            input_ids=input_ids,
            eos_token_id=tokenizer.eos_token_id,
            **generation_args
        )
        print("Generation finished.")

        # Decode only the newly generated tokens, dropping the prompt prefix.
        generated_part = output_ids[0][prompt_len:]
        output_text = tokenizer.decode(generated_part, skip_special_tokens=True)

        # Trim anything after an explicit end-of-sequence marker.
        if '</s>' in output_text:
            output_text = output_text.split('</s>')[0].strip()

        return {
            "input_prompt": prompt,
            "generated_text": output_text
        }
    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))
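

# A hedged addition, not in the original file: a minimal local entry point,
# assuming the `uvicorn` package is installed (FastAPI does not bundle a
# server). On Hugging Face Spaces the app is typically launched externally,
# e.g. `uvicorn app:app --host 0.0.0.0 --port 7860` (module name assumed),
# in which case this block is never executed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)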