Mostafa174's picture
triggering redeployment
f676a0c
raw
history blame
6.11 kB
import gradio as gr
import os
import numpy as np
from scipy.special import expit
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from PyPDF2 import PdfReader
from docx import Document
# Load Model and Tokenizer
# Both are loaded once, at import time, from the checkpoint named by MODEL and
# are shared by every analyzer function in this module.
MODEL = "cardiffnlp/tweet-topic-21-multi"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL)
# Lookup from class index to label name, taken from the model config; the
# analyzers use it to name the topic behind each score.
class_mapping = model.config.id2label
# Text Analyzer
def analyze_topics(text):
    """Detect topics in a short piece of text.

    The text is tokenized (truncated to 512 tokens) and run through the
    module-level classifier. Each logit is passed through a sigmoid so that
    every label gets its own independent score, and labels scoring at least
    0.5 are reported.

    Args:
        text: The text to classify. ``None``, empty, or whitespace-only input
            is rejected without running the model.

    Returns:
        A newline-separated bullet list of ``• <topic> (<score>)`` entries in
        label-index order, or a short message when there is no input or no
        label reaches the threshold.
    """
    # Nothing to classify: skip tokenization and inference entirely.
    if text is None or not text.strip():
        return "Please enter some text first."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    outputs = model(**inputs)
    # Sigmoid (not softmax): labels are scored independently of each other.
    scores = expit(outputs.logits[0].detach().numpy())

    detected_topics = [
        f"• {class_mapping[int(i)]} ({scores[i]:.2f})"
        for i in np.flatnonzero(scores >= 0.5)
    ]
    if not detected_topics:
        return "No specific topics detected."
    return "\n".join(detected_topics)
# Document Analyzer Helpers
def extract_text_from_file(file_path):
    """Return the plain text of a PDF, DOCX, or TXT file.

    The format is chosen from the file extension, compared case-insensitively.

    Args:
        file_path: Path to the document on disk.

    Returns:
        The extracted text with leading and trailing whitespace removed.
        PDF pages are joined with single spaces (pages that yield no text are
        skipped); DOCX paragraphs are joined with newlines.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx``, or ``.txt``.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == ".pdf":
        reader = PdfReader(file_path)
        # Extract each page exactly once; extraction is the expensive step.
        page_texts = (page.extract_text() for page in reader.pages)
        text = " ".join(page_text for page_text in page_texts if page_text)
    elif ext == ".docx":
        doc = Document(file_path)
        text = "\n".join(p.text for p in doc.paragraphs)
    elif ext == ".txt":
        # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading
        # byte-order mark, which str.strip() would otherwise leave in place.
        with open(file_path, "r", encoding="utf-8-sig") as f:
            text = f.read()
    else:
        raise ValueError("Unsupported file format. Please upload a PDF, DOCX, or TXT file.")
    return text.strip()
def analyze_document(file):
    """Detect topics across an uploaded document.

    The document text is split into 400-word chunks, each chunk is classified
    independently, and every label that scores at least 0.5 in a chunk has
    that score recorded. Topics are then ranked by their average recorded
    score.

    Args:
        file: The uploaded document, given either as a path (``str`` or
            ``os.PathLike``) or as an object exposing the path as ``.name``.
            ``None`` means nothing was uploaded.

    Returns:
        A newline-separated bullet list of
        ``• <topic> (avg confidence: <score>)`` entries, highest average
        first, or a short message when no file was given, the file has no
        readable text, or no label reaches the threshold in any chunk.

    Raises:
        ValueError: Propagated from ``extract_text_from_file`` for an
            unsupported file extension.
    """
    if file is None:
        return "Please upload a document first."

    # Accept a bare path as well as an upload wrapper carrying ``.name``.
    if isinstance(file, (str, os.PathLike)):
        file_path = os.fspath(file)
    else:
        file_path = file.name

    text = extract_text_from_file(file_path)
    if not text:
        return "No readable text found in document."

    # Split into chunks for large docs.
    # NOTE: chunks are sized in words but the tokenizer limit is 512 tokens,
    # so a chunk that tokenizes longer than that is truncated at the end.
    words = text.split()
    chunks = (" ".join(words[i:i + 400]) for i in range(0, len(words), 400))

    topic_scores = {}
    for chunk in chunks:
        inputs = tokenizer(chunk, return_tensors="pt", truncation=True, max_length=512)
        outputs = model(**inputs)
        # Sigmoid (not softmax): labels are scored independently of each other.
        scores = expit(outputs.logits[0].detach().numpy())
        for i in np.flatnonzero(scores >= 0.5):
            topic_scores.setdefault(class_mapping[int(i)], []).append(scores[i])

    if not topic_scores:
        return "No specific topics detected in document."

    # The average covers only the chunks in which the topic was detected.
    # Rank on the numeric mean rather than on the rounded, formatted text.
    averages = {topic: float(np.mean(confs)) for topic, confs in topic_scores.items()}
    ranked = sorted(averages.items(), key=lambda item: item[1], reverse=True)
    return "\n".join(
        f"• {topic} (avg confidence: {avg:.2f})" for topic, avg in ranked
    )
# Custom stylesheet handed to gr.TabbedInterface via its `css` argument:
# dark background (#1a1a1a) with orange (#ff9900) accents.
# NOTE(review): the `.svelte-*` selectors look like build-generated class
# names and may not match other Gradio versions - confirm against the
# deployed version.
css = """
/* --- Global Layout --- */
body {
background-color: #1a1a1a !important;
color: #f5f5f5 !important;
font-family: 'Inter', sans-serif !important;
margin: 0 !important;
padding: 0 !important;
}
/* Full width */
#root, .gradio-container, .main {
max-width: 100% !important;
width: 100% !important;
background-color: #1a1a1a !important;
margin: 0 !important;
padding: 0 !important;
border: none !important;
box-shadow: none !important;
}
/* Headings and Labels */
h1, h2, h3, label {
color: #ff9900 !important;
font-weight: 600 !important;
}
/* Text Inputs */
textarea, input {
background-color: #2a2a2a !important;
color: #f5f5f5 !important;
border: 1px solid #3a3a3a !important;
border-radius: 10px !important;
padding: 12px !important;
}
/* Buttons */
button {
background-color: #ff9900 !important;
color: #1a1a1a !important;
font-weight: 600 !important;
border-radius: 8px !important;
border: none !important;
padding: 8px 16px !important;
transition: 0.25s ease-in-out;
}
button:hover {
background-color: #ffb84d !important;
}
/* Output textbox */
.output-textbox {
background-color: #252525 !important;
color: #ffd480 !important;
border: 1px solid #3a3a3a !important;
border-radius: 10px !important;
box-shadow: inset 0 0 6px rgba(255,153,0,0.1);
}
/* Tabs */
.tabitem.svelte-1ipelgc {
background-color: #1a1a1a !important;
color: #ffb84d !important;
}
.tabitem.svelte-1ipelgc.selected {
background-color: #ff9900 !important;
color: #1a1a1a !important;
font-weight: 700 !important;
}
/* Footer */
.footer, .svelte-1xdkkgx, .wrap.svelte-1ipelgc {
background: none !important;
border: none !important;
box-shadow: none !important;
color: #888 !important;
text-align: center !important;
}
"""
# -------------------------
# Gradio Interface
# -------------------------

# Tab 1: topic detection on free text typed or pasted by the user.
text_input = gr.Textbox(
    lines=4,
    placeholder="Type or paste text here...",
    label="📝 Enter Text",
)
text_output = gr.Textbox(label="🎯 Detected Topics")
text_examples = [
    ["Just watched the new Marvel movie, it was amazing!"],
    ["Bitcoin prices are going up again!"],
    ["Climate change is affecting polar bears."],
]
tweet_tab = gr.Interface(
    title="💬 Text Topic Analyzer",
    description="Analyze short texts or tweets to detect underlying topics using CardiffNLP’s Tweet Topic model.",
    fn=analyze_topics,
    inputs=text_input,
    outputs=text_output,
    examples=text_examples,
)

# Tab 2: topic detection on an uploaded PDF, DOCX, or TXT document.
document_input = gr.File(label="📄 Upload Document (PDF, DOCX, or TXT)")
document_output = gr.Textbox(label="📘 Detected Topics")
document_tab = gr.Interface(
    title="📄 Document Topic Analyzer",
    description="Upload a document and let the AI detect key topics discussed inside.",
    fn=analyze_document,
    inputs=document_input,
    outputs=document_output,
)

# Combine both tabs into one app that shares the custom stylesheet and theme.
tab_interfaces = [tweet_tab, document_tab]
tab_labels = ["💬 Text Analyzer", "📄 Document Analyzer"]
orange_theme = gr.themes.Base(primary_hue="orange", secondary_hue="orange")
app = gr.TabbedInterface(
    tab_interfaces,
    tab_labels,
    title="🧠 AI Topic Analyzer",
    theme=orange_theme,
    css=css,
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()