from fastapi import FastAPI, HTTPException from pydantic import BaseModel import torch import transformers import charactertokenizer # Импортируем новый токенизатор import os # --- 1. Настройка приложения и модели --- # Определяем устройство. Для бесплатных HF Spaces это всегда 'cpu'. # Использование os.environ.get для гибкости, если вы переключитесь на GPU. DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" MODEL_NAME = 'ai-forever/charllama-2.6B' # Инициализация FastAPI приложения app = FastAPI( title="CharLLaMA 2.6B API", description="API для генерации текста с использованием модели ai-forever/charllama-2.6B", version="1.0.0" ) # --- 2. Загрузка модели и токенизатора --- # Глобальные переменные для модели и токенизатора model = None tokenizer = None # Обернем загрузку в try-except для отлова ошибок при старте try: print(f"Загрузка токенизатора {MODEL_NAME}...") tokenizer = charactertokenizer.CharacterTokenizer.from_pretrained(MODEL_NAME) print(f"Загрузка модели {MODEL_NAME} на устройство {DEVICE}...") # Для CPU-инстанций используем torch.float32. Если бы была GPU, можно было бы использовать float16 model = transformers.AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float32 ) model.to(DEVICE) print("Модель и токенизатор успешно загружены.") except Exception as e: print(f"Критическая ошибка при загрузке модели: {e}") # Если модель не загрузилась, приложение будет возвращать ошибку. # --- 3. Определение моделей данных (Pydantic) --- class GenerationInput(BaseModel): prompt: str max_length: int = 512 # Даем пользователю возможность управлять параметрами temperature: float = 0.8 top_p: float = 0.6 # --- 4. Создание эндпоинтов API --- @app.get("/") def read_root(): """Корневой эндпоинт для проверки работоспособности.""" return {"status": "API is running", "model_loaded": model is not None} @app.post("/generate") def generate_text(request: GenerationInput): """ Эндпоинт для генерации текста. Принимает JSON с полем 'prompt' и опциональными параметрами генерации. """ if not model or not tokenizer: raise HTTPException( status_code=503, detail="Модель не была загружена. Проверьте логи Space." ) prompt = request.prompt # Параметры генерации из запроса и примера generation_args = { 'max_length': request.max_length, 'num_return_sequences': 1, 'do_sample': True, 'no_repeat_ngram_size': 10, 'temperature': request.temperature, 'top_p': request.top_p, 'top_k': 0, } try: # 1. Токенизация входного текста input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(DEVICE) prompt_len = input_ids.shape[1] # 2. Генерация print("Начинаю генерацию...") output_ids = model.generate( input_ids=input_ids, eos_token_id=tokenizer.eos_token_id, **generation_args ) print("Генерация завершена.") # 3. Декодирование и постобработка # Декодируем только сгенерированную часть, исключая исходный промпт generated_part = output_ids[0][prompt_len:] output_text = tokenizer.decode(generated_part, skip_special_tokens=True) # Убираем все, что идет после токена конца последовательности, если он есть if '' in output_text: output_text = output_text.split('')[0].strip() return { "input_prompt": prompt, "generated_text": output_text } except Exception as e: print(f"Ошибка во время генерации: {e}") raise HTTPException(status_code=500, detail=str(e))