|
|
import gradio as gr |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
from torchvision.models import resnet18 |
|
|
from PIL import Image |
|
|
import base64 |
|
|
import io |
|
|
|
|
|
|
|
|
# Class names in output order: entry i of the model's final layer (and of the
# softmax computed in predict) is reported under labels[i].
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]

# Accent color applied to the page heading in the UI header.
theme_color = "#6C5B7B"
|
|
|
|
|
|
|
|
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head.

    The backbone is built with ``weights=None`` (no pretrained download); the
    weights are expected to be supplied afterwards via ``load_state_dict``.
    The attribute names ``cnn_layers`` / ``fc_layers`` and the order of the
    modules inside ``fc_layers`` determine the state-dict keys, so they must
    not change or an existing checkpoint will no longer load.
    """

    def __init__(self):
        super().__init__()
        num_classes = len(labels)

        # The head below consumes 1000 features, i.e. the output of resnet18's
        # own (default-sized) classifier layer.
        self.cnn_layers = resnet18(weights=None)

        # NOTE(review): there is no activation between the first two Linear
        # layers. Presumably this matches how the checkpoint was trained;
        # inserting a module here would also shift the Sequential indices and
        # therefore the state-dict keys.
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores, one column per entry of ``labels``."""
        features = self.cnn_layers(x)
        return self.fc_layers(features)
|
|
|
|
|
# Inference-time preprocessing: resize to a fixed 224x224, convert to a float
# tensor, then normalize each RGB channel with the usual ImageNet mean/std.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the network once at import time, load the checkpoint onto the CPU and
# switch to eval mode so Dropout is inactive during inference.
model = Classifier()
model.load_state_dict(
    torch.load("classify_nsfw_v3.0.pth", map_location="cpu")
)
model.eval()
|
|
|
|
|
|
|
|
def predict(image_input):
    """Classify an image and return the top label plus per-class probabilities.

    Supported inputs:
      - a PIL image (web UI, ``gr.Image(type="pil")``);
      - a base64 string (API), optionally prefixed with a
        ``data:image/...;base64,`` data-URI header.

    Returns:
        ``(label, probs)`` where ``label`` is the highest-probability class name
        and ``probs`` maps every name in ``labels`` to its softmax probability.
        On any failure (including missing input) it returns
        ``("Error: <message>", {})`` instead of raising, so the Gradio UI shows
        the message in the output widgets.
    """
    # Nothing to classify: clicking "Analizza" with no upload passes None, and
    # an empty textbox passes "". Report that clearly instead of surfacing an
    # AttributeError / PIL decoding error.
    if image_input is None or (
        isinstance(image_input, str) and not image_input.strip()
    ):
        return "Error: no image provided", {}

    try:
        if isinstance(image_input, str):
            payload = image_input.strip()
            if payload.startswith("data:image"):
                # Drop the "data:image/...;base64," header, keep the payload.
                _, sep, payload = payload.partition(",")
                if not sep:
                    raise ValueError(
                        "malformed data URI: missing ',' before the base64 payload"
                    )
            img_bytes = base64.b64decode(payload)
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        else:
            img = image_input.convert("RGB")

        # Add a batch dimension of 1 for the model.
        img_tensor = preprocess(img).unsqueeze(0)

        with torch.no_grad():
            logits = model(img_tensor)
        probs = torch.nn.functional.softmax(logits[0], dim=0)

        probs_dict = {name: float(p) for name, p in zip(labels, probs.tolist())}
        max_label = max(probs_dict, key=probs_dict.get)
        return max_label, probs_dict

    except Exception as e:
        # UI/API boundary: report the failure in the outputs instead of crashing.
        return f"Error: {e}", {}
|
|
|
|
|
def clear_all():
    """Reset the two input widgets.

    Returns:
        ``(image, base64_text)``: ``None`` empties the ``gr.Image`` component
        (an empty string is not a valid image value) and ``""`` empties the
        base64 textbox.
    """
    return None, ""
|
|
|
|
|
|
|
|
def _predict_ui(image, base64_text):
    """Handle the web-UI "Analizza" click: use the uploaded image, else the pasted base64."""
    return predict(image if image is not None else base64_text)


with gr.Blocks(title="NSFW Image Classifier") as demo:

    # Header banner.
    gr.HTML(f"""
    <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
        <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
        <p>Carica un'immagine o incolla la stringa base64 per analizzarla.</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=2):
            img_input = gr.Image(label="📷 Carica immagine", type="pil")
            base64_input = gr.Textbox(
                label="📤 Base64 dell'immagine (API)",
                lines=6,
                placeholder="Incolla qui la stringa base64...",
            )
            with gr.Row():
                submit_btn = gr.Button("✨ Analizza", variant="primary")
                clear_btn = gr.Button("🔄 Pulisci", variant="secondary")

        with gr.Column(scale=1):
            label_output = gr.Textbox(label="Classe predetta", interactive=False)
            result_display = gr.Label(
                label="Distribuzione probabilità", num_top_classes=len(labels)
            )

    # UI button: reads both inputs so a pasted base64 string works too (as the
    # header text promises). It gets an explicit api_name because in Gradio 4+
    # an unnamed event is published under its function's name; leaving this
    # as fn=predict would claim "/predict" and push the base64 endpoint below
    # to "/predict_1".
    submit_btn.click(
        fn=_predict_ui,
        inputs=[img_input, base64_input],
        outputs=[label_output, result_display],
        api_name="predict_ui",
    )
    clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])

    # Hidden trigger whose only purpose is to expose `predict` over the API as
    # "/predict", taking the base64 string as its single input.
    api_button = gr.Button(visible=False)
    api_button.click(
        fn=predict,
        inputs=[base64_input],
        outputs=[label_output, result_display],
        api_name="predict",
    )
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 and keep the "Use via API" docs visible.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
    }
    demo.launch(**launch_options)
|
|
|