import gradio as gr
from transformers import AutoModel, AutoTokenizer
from sklearn.cluster import KMeans
from kneed import KneeLocator
import torch
import re
import json


# ===== Quote Extractor =====
def extract_dialogues_with_context_from_text(raw_text, context_window=1):
    """Extract double-quoted dialogue spans along with surrounding context lines."""
    lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
    dialogue_data = []
    for i, line in enumerate(lines):
        # Match straight or curly double quotes.
        quotes = re.findall(r'[“"]([^”"]+)[”"]', line)
        for quote in quotes:
            context_lines = (
                lines[max(0, i - context_window):i]
                + lines[i + 1:i + 1 + context_window]
            )
            context = " ".join(context_lines)
            dialogue_data.append({
                "quote": quote,
                "context": context,
                "line_index": i,
            })
    return dialogue_data


# ===== Encoder =====
def encode_quote(context: str, dialogue: str, tokenizer, model) -> torch.Tensor:
    """Encode a (context, quote) pair and return the [CLS] token embedding."""
    text = f"{context} [SEP] {dialogue}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():  # inference only; skip gradient tracking
        outputs = model(**inputs)
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    return cls_embedding.squeeze(0)


def load_encoder():
    tokenizer = AutoTokenizer.from_pretrained("aNameNobodyChose/quote-caster-encoder")
    model = AutoModel.from_pretrained("aNameNobodyChose/quote-caster-encoder")
    model.eval()
    return tokenizer, model


def embed_quotes(data, tokenizer, model):
    embeddings = []
    for ex in data:
        emb = encode_quote(ex["context"], ex["quote"], tokenizer, model)
        embeddings.append(emb)
    return torch.stack(embeddings)


def auto_k_via_elbow(embeddings, max_k=10):
    """Pick the number of clusters via the elbow of the KMeans inertia curve."""
    X = embeddings.detach().numpy()
    # Can't request more clusters than there are samples.
    max_k = min(max_k, len(X))
    inertias = []
    for k in range(1, max_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=42, n_init="auto")
        kmeans.fit(X)
        inertias.append(kmeans.inertia_)
    knee = KneeLocator(range(1, max_k + 1), inertias, curve="convex", direction="decreasing")
    # Fall back to 2 clusters if no clear elbow is found, capped at the sample count.
    return min(knee.knee or 2, max_k)


# ===== Pipeline =====
def predict(story_text):
    try:
        data = extract_dialogues_with_context_from_text(story_text)
        if not data:
            return "❌ No quotes found in story. Make sure quotes are enclosed in double quotes (\")."
        tokenizer, model = load_encoder()
        embeddings = embed_quotes(data, tokenizer, model)
        k = auto_k_via_elbow(embeddings)
        labels = KMeans(n_clusters=k, random_state=42, n_init="auto").fit_predict(
            embeddings.detach().numpy()
        )
        for quote, cluster_id in zip(data, labels):
            quote["predicted_speaker"] = f"SPEAKER_{cluster_id}"
        return json.dumps(data, indent=2, ensure_ascii=False)
    except Exception as e:
        return f"❌ Error: {e}"


# ===== Gradio App =====
gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=30, label="Paste full story text (with quotes in double-quotes)"),
    outputs="textbox",
    title="🗣️ QuoteCaster - Speaker Attribution from Raw Text",
    description="Paste a full story containing dialogue in double quotes. The model will extract, embed, and cluster quotes by speaker.",
).launch(server_name="0.0.0.0", server_port=7860)