poetist / app.py
Acrosoc's picture
Update app.py
769df9b verified
raw
history blame
4.84 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import transformers
import charactertokenizer # Импортируем новый токенизатор
import os
# --- 1. Настройка приложения и модели ---
# Определяем устройство. Для бесплатных HF Spaces это всегда 'cpu'.
# Использование os.environ.get для гибкости, если вы переключитесь на GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
MODEL_NAME = 'ai-forever/charllama-2.6B'
# Инициализация FastAPI приложения
app = FastAPI(
title="CharLLaMA 2.6B API",
description="API для генерации текста с использованием модели ai-forever/charllama-2.6B",
version="1.0.0"
)
# --- 2. Загрузка модели и токенизатора ---
# Глобальные переменные для модели и токенизатора
model = None
tokenizer = None
# Обернем загрузку в try-except для отлова ошибок при старте
try:
print(f"Загрузка токенизатора {MODEL_NAME}...")
tokenizer = charactertokenizer.CharacterTokenizer.from_pretrained(MODEL_NAME)
print(f"Загрузка модели {MODEL_NAME} на устройство {DEVICE}...")
# Для CPU-инстанций используем torch.float32. Если бы была GPU, можно было бы использовать float16
model = transformers.AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float32
)
model.to(DEVICE)
print("Модель и токенизатор успешно загружены.")
except Exception as e:
print(f"Критическая ошибка при загрузке модели: {e}")
# Если модель не загрузилась, приложение будет возвращать ошибку.
# --- 3. Определение моделей данных (Pydantic) ---
class GenerationInput(BaseModel):
prompt: str
max_length: int = 512 # Даем пользователю возможность управлять параметрами
temperature: float = 0.8
top_p: float = 0.6
# --- 4. Создание эндпоинтов API ---
@app.get("/")
def read_root():
"""Корневой эндпоинт для проверки работоспособности."""
return {"status": "API is running", "model_loaded": model is not None}
@app.post("/generate")
def generate_text(request: GenerationInput):
"""
Эндпоинт для генерации текста.
Принимает JSON с полем 'prompt' и опциональными параметрами генерации.
"""
if not model or not tokenizer:
raise HTTPException(
status_code=503,
detail="Модель не была загружена. Проверьте логи Space."
)
prompt = request.prompt
# Параметры генерации из запроса и примера
generation_args = {
'max_length': request.max_length,
'num_return_sequences': 1,
'do_sample': True,
'no_repeat_ngram_size': 10,
'temperature': request.temperature,
'top_p': request.top_p,
'top_k': 0,
}
try:
# 1. Токенизация входного текста
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(DEVICE)
prompt_len = input_ids.shape[1]
# 2. Генерация
print("Начинаю генерацию...")
output_ids = model.generate(
input_ids=input_ids,
eos_token_id=tokenizer.eos_token_id,
**generation_args
)
print("Генерация завершена.")
# 3. Декодирование и постобработка
# Декодируем только сгенерированную часть, исключая исходный промпт
generated_part = output_ids[0][prompt_len:]
output_text = tokenizer.decode(generated_part, skip_special_tokens=True)
# Убираем все, что идет после токена конца последовательности, если он есть
if '</s>' in output_text:
output_text = output_text.split('</s>')[0].strip()
return {
"input_prompt": prompt,
"generated_text": output_text
}
except Exception as e:
print(f"Ошибка во время генерации: {e}")
raise HTTPException(status_code=500, detail=str(e))