Siavassh-LLAMA4 / app.py
Siavassh's picture
Update app.py
706ac60 verified
raw
history blame contribute delete
892 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
app = FastAPI()
# Model and token
# Gated repo on the Hugging Face Hub; the access token is read from a Space
# secret literally named "test".
# NOTE(review): "meta-llama/Llama-4-11B-Instruct" does not match any published
# Llama 4 checkpoint id (official ones are e.g. Llama-4-Scout-17B-16E-Instruct)
# — verify the exact repo id on the Hub before deploying.
model_name = "meta-llama/Llama-4-11B-Instruct"
token = os.getenv("test") # use the secret you named in HF Space
# Load tokenizer + model once at process startup so every request reuses them.
# Fix: `use_auth_token` is deprecated in recent transformers releases; the
# supported keyword for Hub authentication is `token`.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision to fit the model in GPU memory
    device_map="auto",          # let accelerate place layers on available devices
    token=token,
)
# Input schema
class InputText(BaseModel):
    """Request body for POST /predict: the raw prompt text to complete."""

    text: str
# API endpoint
@app.post("/predict")
def predict(item: InputText):
    """Generate up to 200 new tokens continuing ``item.text``.

    Returns ``{"result": <decoded text>}``. The decoded string includes the
    original prompt, since the entire output sequence (prompt + generated
    tokens) is decoded.
    """
    inputs = tokenizer(item.text, return_tensors="pt").to(model.device)
    # Fix: inference only — disable autograd so no gradient graph is built
    # per request (saves GPU memory and time).
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"result": result}