# nsfw / app.py
# Author: ALSv
# Last commit: a04766c ("Update app.py", verified)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image
import base64
import io
# ---------------- CONFIG ----------------
# Class names, in the order of the classifier's output units:
# logit/probability i corresponds to labels[i].
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]
# Accent colour used for the page heading in the UI.
theme_color = "#6C5B7B"
# ---------------- MODEL ----------------
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head.

    The attribute names ``cnn_layers`` / ``fc_layers`` and the order of the
    layers inside ``fc_layers`` determine the state-dict keys, so they must
    stay in sync with the checkpoint loaded at start-up.

    NOTE(review): the head has no activation between the first two Linear
    layers (only Dropout). It is left as-is because changing it would alter
    the outputs of the trained model.
    """

    def __init__(self):
        super().__init__()
        # Randomly initialised ResNet-18; real weights come from the checkpoint.
        self.cnn_layers = resnet18(weights=None)
        head_layers = [
            nn.Linear(1000, 512),  # the backbone emits 1000 features
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, len(labels)),  # one output unit per class
        ]
        self.fc_layers = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw, unnormalised class scores (logits) for the batch ``x``."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)
# Input pipeline: resize to exactly 224x224 (aspect ratio is not preserved),
# convert to a float tensor, then normalise each channel with the standard
# ImageNet mean/std values.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the network and load the trained weights onto the CPU. The relative
# path is resolved against the current working directory.
# NOTE(review): torch.load unpickles the file, so only point it at a trusted
# checkpoint (torch>=1.13 also offers weights_only=True).
_CHECKPOINT_PATH = "classify_nsfw_v3.0.pth"
model = Classifier()
model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location="cpu"))
# Inference mode: disables Dropout so predictions are deterministic.
model.eval()
# ---------------- PREDICTION ----------------
def _load_image(image_input):
    """Turn a PIL image or a base64 string into an RGB PIL image.

    A base64 string may be wrapped in a data URL ("data:image/...;base64,...")
    and may carry surrounding whitespace.

    Raises:
        ValueError: no image was supplied, or the data URL has no payload.
        Errors from base64 decoding or PIL parsing propagate unchanged.
    """
    if image_input is None:
        # An empty gr.Image upload arrives as None.
        raise ValueError("no image provided")
    if isinstance(image_input, str):
        # Strip first so a leading space/newline does not defeat the data-URL check.
        payload = image_input.strip()
        if not payload:
            raise ValueError("no image provided")
        if payload.startswith("data:image"):
            # Keep only what follows the first comma of the data URL.
            _, sep, payload = payload.partition(",")
            if not sep:
                raise ValueError(
                    "malformed data URL: missing ',' before the base64 payload"
                )
        img_bytes = base64.b64decode(payload)
        # convert() returns a fully loaded, independent copy, so the opened
        # image can be closed as soon as we leave the with-block.
        with Image.open(io.BytesIO(img_bytes)) as img:
            return img.convert("RGB")
    return image_input.convert("RGB")


def predict(image_input):
    """Classify one image.

    Accepts either:
      - a PIL image (web UI upload), or
      - a base64 string, optionally prefixed as a data URL (API).

    Returns:
        ``(label, probabilities)``: the most probable class name and a
        ``{class_name: probability}`` dict covering every entry of ``labels``.
        On any failure the pair is ``("Error: <message>", {})``, so the UI/API
        caller sees the problem instead of an exception.
    """
    try:
        img = _load_image(image_input)
        img_tensor = preprocess(img).unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            logits = model(img_tensor)
        probs = torch.nn.functional.softmax(logits[0], dim=0)
        probs_dict = {label: float(p) for label, p in zip(labels, probs)}
        max_label = max(probs_dict, key=probs_dict.get)
        return max_label, probs_dict
    except Exception as e:  # deliberate top-level boundary: report, don't crash
        return f"Error: {e}", {}
def clear_all():
    """Reset the two input widgets (image upload and base64 textbox).

    Returns one value per output component, in order: ``None`` empties the
    ``gr.Image`` (an empty string is not a valid image value) and ``""``
    empties the ``gr.Textbox``.
    """
    return None, ""
# ---------------- INTERFACE ----------------
def _predict_ui(image, base64_text):
    """Classify whatever the user supplied in the web UI.

    The uploaded image takes precedence. Otherwise the pasted base64 text is
    used. Both paths go through ``predict``, which reports missing or invalid
    input as an "Error: ..." label.
    """
    return predict(image if image is not None else base64_text)


with gr.Blocks(title="NSFW Image Classifier") as demo:
    gr.HTML(f"""
<div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
<h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
<p>Carica un'immagine o incolla la stringa base64 per analizzarla.</p>
</div>
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Inputs: an uploaded image (delivered as a PIL image) and/or a
            # pasted base64 string.
            img_input = gr.Image(label="📷 Carica immagine", type="pil")
            base64_input = gr.Textbox(
                label="📤 Base64 dell'immagine (API)",
                lines=6,
                placeholder="Incolla qui la stringa base64...",
            )
            with gr.Row():
                submit_btn = gr.Button("✨ Analizza", variant="primary")
                clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
        with gr.Column(scale=1):
            label_output = gr.Textbox(label="Classe predetta", interactive=False)
            result_display = gr.Label(
                label="Distribuzione probabilità", num_top_classes=len(labels)
            )

    # ---------------- UI events ----------------
    # Pass both inputs so a pasted base64 string is used when no image is
    # uploaded. The explicit, distinct api_name keeps "predict" free for the
    # base64 endpoint below. Recent Gradio versions name events after their
    # function and de-duplicate clashing api_names by adding a suffix.
    submit_btn.click(
        fn=_predict_ui,
        inputs=[img_input, base64_input],
        outputs=[label_output, result_display],
        api_name="predict_ui",
    )
    clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])

    # ---------------- Hidden button: base64 API endpoint ----------------
    # Never shown in the page; it exists only to register an API event that
    # feeds the base64 textbox into predict.
    api_button = gr.Button(visible=False)
    api_button.click(
        fn=predict,
        inputs=[base64_input],
        outputs=[label_output, result_display],
        api_name="predict",  # URL path depends on the Gradio version
    )
# ---------------- LAUNCH ----------------
if __name__ == "__main__":
    # Listen on all network interfaces on port 7860 and keep the
    # auto-generated API documentation enabled.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_api=True)
    demo.launch(**launch_options)