# -*- coding: utf-8 -*-
"""
Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B + ZeroGPU + Role Gating
- تب «مشاوره» برای همه تعاملی است.
- تب‌های «ایندکس»، «ساخت دیتاست»، «پاکسازی»، «آموزش»، «Weight Tuning» برای بازدیدکننده فقط نمایشی‌اند؛
و سمت‌سرور نیز گِیت نقش دارد (ادمین/بازدیدکننده).
پیش‌نیازها:
- golden_builder.py , weights_sweep.py
- Settings → Secrets: WANDB_API_KEY (در صورت استفاده از W&B)
- Settings → Environment Variables: ADMIN_USERS (مثلاً: haji-mammad, teammate1)
- requirements.txt (ZeroGPU-ready) شامل spaces>=0.42.0
"""
from __future__ import annotations
# --- Telemetry hard-off + ZeroGPU SDK (must be before chroma import) ---
import os, logging
os.environ["CHROMA_TELEMETRY_ENABLED"] = "false"
os.environ["ANONYMIZED_TELEMETRY"] = "false"
import spaces # ZeroGPU SDK
# (optional) reduce log noise
logging.getLogger("chromadb").setLevel(logging.ERROR)
logging.getLogger("posthog").setLevel(logging.CRITICAL)
# -----------------------------------------------------------------------
import sys, re, json, time, pickle, zipfile, warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Optional
import numpy as np
import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import gradio as gr
warnings.filterwarnings("ignore")
# ====== Transformers ======
import transformers as tf
from transformers import (
AutoTokenizer, AutoModelForCausalLM,
Trainer, TrainingArguments, EarlyStoppingCallback
)
# ====== RAG stack ======
import chromadb
from chromadb.config import Settings
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
# ---- Monkeypatch Chroma telemetry (fallback) ----
try:
import chromadb.telemetry as _ctel
try: _ctel.client = None
except Exception: pass
for _n in ("capture", "capture_event"):
if hasattr(_ctel, _n):
try: setattr(_ctel, _n, lambda *a, **k: None)
except Exception: pass
if hasattr(_ctel, "Telemetry"):
try: _ctel.Telemetry().capture = lambda *a, **k: None
except Exception: pass
except Exception:
pass
# -------------------------------------------------
# ========= Persian normalization =========
ZWNJ = "\u200c"
AR_DIGITS = "٠١٢٣٤٥٦٧٨٩"
FA_DIGITS = "۰۱۲۳۴۵۶۷۸۹"
EN_DIGITS = "0123456789"
def normalize_fa(s: str) -> str:
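    """Normalize Persian text: unify Arabic yeh/kaf variants, strip diacritics,
    map Arabic/Persian digits to ASCII, and tighten whitespace around ZWNJ.

    Example: normalize_fa("قيمت ۱۲۳") -> "قیمت 123"
    """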
if not s:
return s
s = s.replace("\u064A", "ی").replace("\u0643", "ک")
s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s)
trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
s = s.translate(trans)
s = re.sub(r"\s*‌\s*", ZWNJ, s)
s = re.sub(r"\s+", " ", s).strip()
return s
# ==========================
# Configs
# ==========================
@dataclass
class ModelConfig:
model_name: str = "Qwen/Qwen2.5-7B-Instruct"
max_input_length: int = 3072
max_new_tokens: int = 256
temperature: float = 0.7
top_p: float = 0.9
do_sample: bool = True
gradient_checkpointing: bool = True
@dataclass
class RAGConfig:
persist_dir: str = "./chroma_db"
collection: str = "legal_articles"
top_k: int = 6
similarity_threshold: float = 0.68
context_char_limit: int = 260
enable: bool = True
reranker_name: str = "Alibaba-NLP/gte-multilingual-reranker-base"
@dataclass
class TrainConfig:
base_model: str = "PartAI/Dorna-Llama3-8B-Instruct"
alt_model_1: str = "zpm/Llama-3.1-PersianQA"
hakim_model: str = "AI-Hoosh/HAKIM-7B"
hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct"
output_dir: str = "./mahoon_causal_lora"
seed: int = 42
test_size: float = 0.1
epochs: int = 2
batch_size: int = 2
grad_accum: int = 4
lr: float = 2e-4
warmup_ratio: float = 0.03
weight_decay: float = 0.0
logging_steps: int = 50
eval_strategy: str = "epoch"
save_strategy: str = "epoch"
save_total_limit: int = 2
report_to: str = "wandb"
max_grad_norm: float = 1.0
use_4bit: bool = False
max_seq_len: int = 2048
@dataclass
class SystemConfig:
model: ModelConfig = field(default_factory=ModelConfig)
rag: RAGConfig = field(default_factory=RAGConfig)
train: TrainConfig = field(default_factory=TrainConfig)
# ==========================
# Helpers
# ==========================
def set_seed_all(seed: int = 42):
import random
random.seed(seed); np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
def bf16_supported():
return torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)()
def log_deps():
try:
import accelerate, datasets
print("[deps]",
f"python={sys.version.split()[0]}",
f"transformers={tf.__version__}",
f"accelerate={accelerate.__version__}",
f"datasets={datasets.__version__}",
f"gradio={gr.__version__}",
flush=True)
except Exception as e:
print("[deps] warn:", e, flush=True)
# ==========================
# Role gating helpers
# ==========================
def _get_username(request: gr.Request) -> str | None:
try:
return getattr(request, "username", None)
except Exception:
return None
def is_admin(request: gr.Request) -> bool:
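    """Admins are the Space author (SPACE_AUTHOR_NAME) or anyone listed in ADMIN_USERS."""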
uname = _get_username(request)
if not uname:
return False
author = os.getenv("SPACE_AUTHOR_NAME", "").strip()
allow = {u.strip() for u in os.getenv("ADMIN_USERS", "").split(",") if u.strip()}
return (uname == author) or (uname in allow)
# ==========================
# RAG: Chroma + BM25 + CrossEncoder reranker
# ==========================
class LegalRAG:
def __init__(self, cfg: RAGConfig):
self.cfg = cfg
self.client = None
self.collection = None
self.reranker: Optional[CrossEncoder] = None
self.bm25 = None
self.bm25_ids: List[str] = []
self.bm25_path = str(Path(self.cfg.persist_dir) / "bm25.pkl")
def init(self):
Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
self.client = chromadb.PersistentClient(
path=self.cfg.persist_dir,
settings=Settings(anonymized_telemetry=False)
)
try:
self.collection = self.client.get_or_create_collection(self.cfg.collection)
except Exception:
try: self.collection = self.client.get_collection(self.cfg.collection)
except Exception: self.collection = self.client.create_collection(self.cfg.collection)
try:
self.reranker = CrossEncoder(self.cfg.reranker_name, device="cpu")
except Exception:
self.reranker = None
if Path(self.bm25_path).exists():
with open(self.bm25_path, "rb") as f:
obj = pickle.load(f)
self.bm25 = obj["bm25"]; self.bm25_ids = obj["ids"]
def _rebuild_bm25(self, ids: List[str], docs: List[str]):
corpus = [normalize_fa(d).split() for d in docs]
self.bm25 = BM25Okapi(corpus)
self.bm25_ids = ids
with open(self.bm25_path, "wb") as f:
pickle.dump({"bm25": self.bm25, "ids": self.bm25_ids}, f)
def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
if not self.collection: self.init()
seen: Dict[str, int] = {}
ids, docs, metas = [], [], []
def _norm_id(x: str) -> str:
x = x or ""
x = x.replace("\u064A", "ی").replace("\u0643", "ک")
            trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
x = x.translate(trans)
x = re.sub(r"\s+", "", x)
return x
with open(jsonl_path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
s = line.strip()
if not s: continue
                try: obj = json.loads(s)
                except json.JSONDecodeError: continue
raw_id = str(obj.get(id_key, f"auto_{i}"))
base_id = _norm_id(raw_id)
txt = normalize_fa(str(obj.get(text_key, "")).strip())
if not txt: continue
if base_id in seen:
seen[base_id] += 1
uid = f"{base_id}__d{seen[base_id]}"
dupe_idx = seen[base_id]
else:
seen[base_id] = 1
uid = base_id
dupe_idx = 1
ids.append(uid); docs.append(txt); metas.append({"article_id": base_id, "dupe_idx": dupe_idx})
if not ids:
return "هیچ سندی برای ایندکس یافت نشد."
self.collection.upsert(ids=ids, documents=docs, metadatas=metas)
self._rebuild_bm25(ids, docs)
        dup_count = sum(1 for c in seen.values() if c > 1)
return f"✅ {len(ids)} سند ایندکس شد (Dense+BM25). شناسه‌های تکراری: {dup_count} کلید (با پسوند __dN یکتا شدند)."
def retrieve(self, query: str) -> List[Dict]:
if not self.collection: return []
qn = normalize_fa(query)
# Dense
try:
res = self.collection.query(
query_texts=[qn],
n_results=max(self.cfg.top_k * 3, 20),
include=["documents", "metadatas", "distances"],
)
out = []
docs = res.get("documents", [[]])[0]
metas = res.get("metadatas", [[]])[0]
dists = res.get("distances", [[1.0]])[0]
for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)):
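                # Assumes the collection uses cosine distance, so similarity = 1 - distance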
sim = 1.0 - float(dist)
out.append({"article_id": (meta or {}).get("article_id", f"unk_{i}"),
"text": doc, "similarity": sim})
except Exception:
out = []
# BM25
bm25_hits = []
if self.bm25 is not None and self.bm25_ids:
            scores = self.bm25.get_scores(qn.split())
idxs = np.argsort(scores)[::-1][:max(self.cfg.top_k * 3, 20)]
smax = float(scores.max() + 1e-8)
for j in idxs:
aid = self.bm25_ids[int(j)]
try:
got = self.collection.get(ids=[aid])
tdoc = got["documents"][0]
except Exception:
tdoc = ""
bm25_hits.append({"article_id": aid, "text": tdoc, "similarity": float(scores[j]) / smax})
        # Merge dense and BM25 candidates, keeping the highest-similarity hit per article_id
pool: Dict[str, Dict] = {}
for a in out + bm25_hits:
if a["article_id"] not in pool or a.get("similarity", 0) > pool[a["article_id"]].get("similarity", 0):
pool[a["article_id"]] = a
merged = [a for a in pool.values() if a.get("text") and len(a["text"]) > 15]
merged = [a for a in merged if a.get("similarity", 0) >= self.cfg.similarity_threshold]
        # Rerank with the cross-encoder (request GPU only for the predict call)
        if merged and self.reranker:
            pairs = [(qn, a["text"]) for a in merged]
            def _predict():
                return self.reranker.predict(pairs)
            try:
                # spaces.GPU is a decorator, not a context manager: wrap the call, then invoke
                scores = spaces.GPU(duration=30)(_predict)()
            except Exception:
                scores = _predict()
for a, s in zip(merged, scores): a["score"] = float(s)
merged = sorted(merged, key=lambda x: x.get("score", 0), reverse=True)[: self.cfg.top_k]
else:
merged = sorted(merged, key=lambda x: x.get("similarity", 0), reverse=True)[: self.cfg.top_k]
return merged
def build_context(self, arts: List[Dict]) -> str:
if not arts: return ""
bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." for a in arts]
return "مواد مرتبط:\n" + "\n".join(bullets)
# ========= RAG bootstrap from repo =========
def parse_law_textfile_to_jsonl(txt_path: str, out_jsonl: str):
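    """Extract lines of the form «ماده <number>: <body>» from a raw law text file into JSONL rows."""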
pat = re.compile(r"(?:ماده|مادّه)\s+(\d+)\s*[:\-–]\s*(.+)")
rows = []
with open(txt_path, "r", encoding="utf-8") as f:
for line in f:
s = line.strip()
if not s: continue
m = pat.match(s)
if not m: continue
aid = m.group(1); body = m.group(2).strip()
if len(body) < 12: continue
rows.append({"article_id": aid, "text": normalize_fa(body)})
if not rows: raise RuntimeError("هیچ ماده‌ای با الگوی تعریف‌شده پیدا نشد.")
with open(out_jsonl, "w", encoding="utf-8") as g:
for r in rows: g.write(json.dumps(r, ensure_ascii=False) + "\n")
return len(rows)
def ensure_chroma_ready(persist_dir="./chroma_db", collection="legal_articles") -> str:
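    """Bootstrap the RAG store: reuse an existing ChromaDB dir, else restore from
    chroma_legal_db.zip, else build from all_legal_sentences.txt."""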
Path(persist_dir).mkdir(parents=True, exist_ok=True)
if any(Path(persist_dir).glob("*")):
return f"ChromaDB موجود است."
zip_path = Path("./chroma_legal_db.zip")
if zip_path.exists():
try:
with zipfile.ZipFile(zip_path, "r") as z: z.extractall(persist_dir)
return "ChromaDB از zip بازیابی شد."
except Exception: pass
txt_path = Path("./all_legal_sentences.txt")
if txt_path.exists():
n = parse_law_textfile_to_jsonl(str(txt_path), "./laws.jsonl")
rag_local = LegalRAG(RAGConfig(persist_dir=persist_dir, collection=collection))
rag_local.init()
msg = rag_local.index_jsonl("./laws.jsonl", id_key="article_id", text_key="text")
return f"از متن خام {n} رکورد استخراج شد. {msg}"
return "پایگاه RAG موجود نیست و منبع خامی هم برای ساخت پیدا نشد."
# ==========================
# Loader + Generator (Causal-only, ZeroGPU)
# ==========================
class CausalLoader:
def __init__(self, mcfg: ModelConfig):
self.cfg = mcfg
self.tokenizer = None
self.model = None
def load(self, model_name: str):
self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"):
self.tokenizer.pad_token = self.tokenizer.eos_token
        def _load():
            kwargs = {"low_cpu_mem_usage": True}
            if torch.cuda.is_available():
                kwargs["device_map"] = "auto"
                kwargs["torch_dtype"] = torch.bfloat16 if bf16_supported() else torch.float16
            return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        try:
            # spaces.GPU is a decorator, not a context manager: wrap the loader, then invoke
            self.model = spaces.GPU(duration=90)(_load)()
        except Exception:
            # Fallback: plain CPU load
            self.model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
        if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
            try: self.model.gradient_checkpointing_enable()
            except Exception: pass
return self
class Generator:
def __init__(self, loader: CausalLoader, mcfg: ModelConfig):
self.tk = loader.tokenizer
self.model = loader.model
self.cfg = mcfg
def generate(self, question: str, context: str = "", system_prompt: str = "You are a helpful Persian legal assistant.") -> str:
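        # NOTE: this builds a generic <|system|>/<|user|>/<|assistant|> prompt rather than
        # each model's native chat template (tokenizer.apply_chat_template); some instruct
        # models may respond better with their own template.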
parts = []
if system_prompt: parts.append(f"<|system|>\n{system_prompt}")
if context: parts.append(f"<|system|>\nاز منابع زیر استفاده کن و استنادی پاسخ بده:\n{context}")
parts.append(f"<|user|>\n{question}")
prompt = "\n".join(parts) + "\n<|assistant|>\n"
enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
        gen_kwargs = dict(
            max_new_tokens=self.cfg.max_new_tokens,
            do_sample=self.cfg.do_sample,
            temperature=self.cfg.temperature,
            top_p=self.cfg.top_p,
            pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id,
        )
        def _run(max_new=None):
            dev = next(self.model.parameters()).device
            inputs = {k: v.to(dev) for k, v in enc.items()}
            kwargs = dict(gen_kwargs)
            if max_new is not None:
                kwargs["max_new_tokens"] = max_new
            with torch.no_grad():
                return self.model.generate(**inputs, **kwargs)
        try:
            # spaces.GPU is a decorator, not a context manager: wrap the call, then invoke
            out = spaces.GPU(duration=60)(_run)()
        except Exception:
            # CPU fallback with a capped token budget to keep latency bounded
            out = _run(max_new=min(self.cfg.max_new_tokens, 256))
return self.tk.decode(out[0], skip_special_tokens=True)
# ==========================
# Datasets & Trainer (Causal-only, W&B)
# ==========================
def read_jsonl_files(paths: List[str]) -> List[Dict]:
data: List[Dict] = []
for p in paths:
if not p: continue
with open(p, 'r', encoding='utf-8') as f:
for line in f:
s = line.strip()
if not s: continue
try: data.append(json.loads(s))
except json.JSONDecodeError: continue
return data
class CausalJSONLDataset(Dataset):
def __init__(self, data: List[Dict], tokenizer, max_len: int, rag: Optional[LegalRAG] = None, enhance_every:int = 8):
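        # To bound retrieval cost, only every `enhance_every`-th example gets RAG context.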
self.tk = tokenizer
self.max_len = max_len
self.items = []
for i, ex in enumerate(data):
src = normalize_fa(str(ex.get("input", "")).strip())
tgt = normalize_fa(str(ex.get("output", "")).strip())
if not src or not tgt: continue
ctx = ""
if rag and i % enhance_every == 0:
arts = rag.retrieve(src)
ctx = rag.build_context(arts)
text = ""
if ctx: text += f"<|system|>\nاز منابع زیر استفاده کن:\n{ctx}\n"
text += f"<|system|>\nYou are a helpful Persian legal assistant.\n"
text += f"<|user|>\n{src}\n<|assistant|>\n{tgt}"
self.items.append(text)
def __len__(self): return len(self.items)
def __getitem__(self, idx):
text = self.items[idx]
enc = self.tk(text, max_length=self.max_len, padding="max_length", truncation=True)
input_ids = torch.tensor(enc["input_ids"])
attn = torch.tensor(enc["attention_mask"])
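        # Mask padded positions with -100 so the LM loss ignores them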
labels = input_ids.clone(); labels[attn == 0] = -100
return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
def safe_training_args(**kwargs):
    """Build TrainingArguments, remapping kwargs renamed across transformers versions."""
    try:
        return TrainingArguments(**kwargs)
    except TypeError:
        # Newer transformers renamed evaluation_strategy -> eval_strategy
        if "evaluation_strategy" in kwargs:
            kwargs["eval_strategy"] = kwargs.pop("evaluation_strategy")
        return TrainingArguments(**kwargs)
class TrainerManager:
def __init__(self, syscfg: SystemConfig, loader: CausalLoader):
self.cfg = syscfg
self.loader = loader
def train_causal(self, train_paths: List[str], use_rag: bool = True, use_wandb: bool = True,
wandb_project: str = "mahoon-legal-ai", wandb_entity: str = "", run_name: str = "mahoon_causal_lora"):
set_seed_all(self.cfg.train.seed)
data = read_jsonl_files(train_paths)
train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
if rag: rag.init()
ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.train.max_seq_len, rag)
ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.train.max_seq_len, None)
fp16_ok = torch.cuda.is_available() and not bf16_supported()
bf16_ok = bf16_supported()
if use_wandb:
os.environ.setdefault("WANDB_PROJECT", wandb_project or "mahoon-legal-ai")
if wandb_entity: os.environ.setdefault("WANDB_ENTITY", wandb_entity)
os.environ.pop("WANDB_DISABLED", None)
else:
os.environ["WANDB_DISABLED"] = "true"
args = safe_training_args(
output_dir=self.cfg.train.output_dir,
num_train_epochs=self.cfg.train.epochs,
learning_rate=self.cfg.train.lr,
per_device_train_batch_size=self.cfg.train.batch_size,
per_device_eval_batch_size=self.cfg.train.batch_size,
gradient_accumulation_steps=self.cfg.train.grad_accum,
warmup_ratio=self.cfg.train.warmup_ratio,
weight_decay=self.cfg.train.weight_decay,
evaluation_strategy=self.cfg.train.eval_strategy,
save_strategy=self.cfg.train.save_strategy,
save_total_limit=self.cfg.train.save_total_limit,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
logging_steps=self.cfg.train.logging_steps,
report_to=(["wandb"] if use_wandb else ["none"]),
run_name=run_name,
fp16=fp16_ok, bf16=bf16_ok,
max_grad_norm=self.cfg.train.max_grad_norm,
)
callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
try:
if use_wandb:
from transformers.integrations import WandbCallback
callbacks.append(WandbCallback())
except Exception:
pass
trainer = Trainer(
model=self.loader.model,
args=args,
train_dataset=ds_tr,
eval_dataset=ds_va,
tokenizer=self.loader.tokenizer,
callbacks=callbacks,
)
if use_wandb:
try:
import wandb
wandb.init(project=os.getenv("WANDB_PROJECT", "mahoon-legal-ai"),
entity=os.getenv("WANDB_ENTITY"),
name=run_name,
config={
"base_model": self.loader.model.name_or_path,
"epochs": self.cfg.train.epochs,
"batch": self.cfg.train.batch_size,
"grad_accum": self.cfg.train.grad_accum,
"lr": self.cfg.train.lr,
"max_seq_len": self.cfg.train.max_seq_len,
"use_rag": use_rag,
})
except Exception:
pass
trainer.train()
trainer.save_model(self.cfg.train.output_dir)
self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
if use_wandb:
try:
import wandb
art = wandb.Artifact("mahoon-model", type="model")
art.add_dir(self.cfg.train.output_dir)
wandb.log_artifact(art)
wandb.finish()
except Exception:
pass
# ==========================
# Dataset utilities (Cleaner/Deduper)
# ==========================
def deduplicate_jsonl(in_path: str, out_path: str, sim_threshold: float = 0.90, text_keys=("input","output")) -> int:
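    """Greedy semantic dedup: embed each record's "input", then keep only the first record
    of every cosine-similarity cluster at or above sim_threshold."""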
rows = []
with open(in_path, "r", encoding="utf-8") as f:
for line in f:
s = line.strip()
if not s: continue
            try: obj = json.loads(s)
            except json.JSONDecodeError: continue
for k in text_keys:
if k in obj: obj[k] = normalize_fa(str(obj[k]))
rows.append(obj)
if not rows: raise RuntimeError("هیچ رکورد معتبری در ورودی نبود.")
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
embs = model.encode([r.get("input","") for r in rows], convert_to_tensor=True, show_progress_bar=False, normalize_embeddings=True)
kept, seen = [], torch.zeros(len(rows), dtype=torch.bool)
for i in range(len(rows)):
if seen[i]: continue
sims = st_util.cos_sim(embs[i], embs)[0]
dup_idx = (sims >= sim_threshold).nonzero(as_tuple=True)[0].tolist()
for j in dup_idx: seen[j] = True
kept.append(rows[i])
with open(out_path, "w", encoding="utf-8") as g:
for r in kept: g.write(json.dumps(r, ensure_ascii=False) + "\n")
return len(kept)
# ==========================
# App (Gradio) + Role Gating
# ==========================
class LegalApp:
def __init__(self, scfg: Optional[SystemConfig] = None):
self.scfg = scfg or SystemConfig()
self.rag = LegalRAG(self.scfg.rag)
self.loader: Optional[CausalLoader] = None
self.gen: Optional[Generator] = None
def _file_paths(self, files: List[gr.File]) -> List[str]:
paths = []
for f in (files or []):
p = getattr(f, "name", None) or getattr(f, "path", None)
if p: paths.append(p)
return paths
    # Core (consultation / model loading is open to everyone)
def load(self, model_name: str):
self.loader = CausalLoader(self.scfg.model).load(model_name)
self.gen = Generator(self.loader, self.scfg.model)
# RAG
msg_rag = "RAG غیرفعال"
if self.scfg.rag.enable:
try:
self.rag = LegalRAG(self.scfg.rag); self.rag.init()
msg_rag = "RAG آماده است"
except Exception as e:
msg_rag = f"RAG خطا: {e}"
return f"مدل بارگذاری شد: {model_name}\n{msg_rag}"
    # --- Server-side gate: admins only ---
def build_index(self, laws_file: gr.File, id_key: str, text_key: str, request: gr.Request):
if not is_admin(request):
return "🔒 این عملیات فقط برای مدیران فعال است."
if not self.scfg.rag.enable: return "RAG غیرفعال است."
try:
self.rag.init()
p = getattr(laws_file, "name", None) or getattr(laws_file, "path", None)
if not p: return "فایل قوانین معتبر نیست."
return self.rag.index_jsonl(p, id_key=id_key, text_key=text_key)
except Exception as e:
return f"خطا در ایندکس: {e}"
def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None, request: gr.Request):
if not is_admin(request):
return None, "🔒 این عملیات فقط برای مدیران فعال است."
try:
from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
except Exception as e:
return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
if not path: return None, "⚠️ فایل ورودی معتبر نیست."
try:
data = load_json_or_jsonl(path)
if max_samples and int(max_samples) > 0: data = data[:int(max_samples)]
gb = GoldenBuilder(model_name=model_ckpt)
rows = gb.build(data, text_key=text_key, batch_size=int(batch_size))
out_dir = "/tmp/mahoon_datasets"; Path(out_dir).mkdir(parents=True, exist_ok=True)
out_path = f"{out_dir}/golden_{os.path.basename(path)}.jsonl"
save_jsonl(rows, out_path)
return out_path, f"✅ {len(rows)} رکورد تولید شد."
except Exception as e:
return None, f"❌ خطا در ساخت دیتاست: {e}"
def train(self, model_name: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float,
use_wandb: bool, wandb_project: str, wandb_entity: str, run_name: str,
progress=gr.Progress(track_tqdm=True), request: gr.Request = None):
if not is_admin(request):
return "🔒 این عملیات فقط برای مدیران فعال است."
progress(0.05, desc="راه‌اندازی")
self.scfg.train.epochs = int(epochs)
self.scfg.train.batch_size = int(batch)
self.scfg.train.lr = float(lr)
progress(0.10, desc="بارگذاری مدل/توکنایزر")
self.loader = CausalLoader(self.scfg.model).load(model_name)
paths = self._file_paths(files)
if not paths: return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده."
tm = TrainerManager(self.scfg, self.loader)
set_seed_all(self.scfg.train.seed)
progress(0.30, desc="آماده‌سازی دیتاست‌ها و RAG (اختیاری)")
tm.train_causal(
paths, use_rag=use_rag, use_wandb=use_wandb,
wandb_project=wandb_project, wandb_entity=wandb_entity, run_name=run_name
)
progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."
def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent, request: gr.Request):
if not is_admin(request):
return "🔒 این عملیات فقط برای مدیران فعال است."
p = getattr(f, "name", None) or getattr(f, "path", None)
if not p:
return "⚠️ فایل داده نامعتبر است."
try:
from weights_sweep import run_sweep
except Exception as e:
return f"❌ weights_sweep.py یافت نشد/قابل import نیست: {e}"
os.environ.setdefault("WANDB_PROJECT", proj or "mahoon-legal-ai")
if ent: os.environ.setdefault("WANDB_ENTITY", ent)
try:
run_sweep(data_path=p, text_key=tk, max_samples=int(ms), batch_size=int(bs),
project=proj, entity=ent, count=int(runs))
return "✅ Sweep اجرا شد. بهترین Run را در W&B بررسی و وزن‌ها را تثبیت کنید."
except Exception as e:
return f"❌ خطا در اجرای Sweep: {e}"
def apply_best_weights(self, wandb_project: str, wandb_entity: str, metric: str = "pass_rate", request: gr.Request = None):
if request is not None and not is_admin(request):
return "🔒 این عملیات فقط برای مدیران فعال است."
try:
import wandb, json as _json
except Exception as e:
return f"❌ W&B در محیط در دسترس نیست: {e}"
try:
api = wandb.Api()
proj_path = f"{wandb_entity}/{wandb_project}" if wandb_entity else wandb_project
runs = api.runs(proj_path, filters={"state": "finished"})
except Exception as e:
return f"❌ عدم دسترسی به پروژه W&B ({wandb_project}): {e}"
best_run = None; best_val = float("-inf")
for r in runs:
s = r.summary or {}
if "weights" in s and metric in s:
try: val = float(s[metric])
except Exception: continue
if val > best_val: best_val, best_run = val, r
if not best_run:
return "⚠️ هیچ Run واجد شرایطی با summary['weights'] و متریک موردنظر پیدا نشد."
weights = best_run.summary.get("weights", {})
if not isinstance(weights, dict) or not weights:
return "⚠️ فرمت وزن‌های بهترین Run نامعتبر است."
try:
with open("legal_entity_weights.json", "w", encoding="utf-8") as f:
_json.dump(weights, f, ensure_ascii=False, indent=2)
except Exception as e:
return f"❌ خطا در نوشتن legal_entity_weights.json: {e}"
rid = getattr(best_run, "id", "unknown")
return f"✅ وزن‌ها اعمال شد از Run `{rid}` با {metric}={best_val:.4f}. فایل: `legal_entity_weights.json`"
# Consultation (عمومی)
def answer(self, question: str, system_prompt: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float):
if not question.strip(): return "لطفاً سوال خود را وارد کنید.", ""
if not self.gen: return "ابتدا مدل را بارگذاری کنید.", ""
self.scfg.model.max_new_tokens = int(max_new_tokens)
self.scfg.model.temperature = float(temperature)
self.scfg.model.top_p = float(top_p)
arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
max_refs = 4
if arts: arts = arts[:max_refs]
ctx = self.rag.build_context(arts) if arts else ""
ans = self.gen.generate(question, ctx, system_prompt)
refs = ""
if arts:
refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a.get('similarity',0):.2f})\n{a['text'][:320]}..." for a in arts])
return ans, refs
# UI
def build_ui(self):
log_deps()
try:
print("[rag-bootstrap]", ensure_chroma_ready(self.scfg.rag.persist_dir, self.scfg.rag.collection), flush=True)
except Exception as e:
print("[rag-bootstrap] error:", e, flush=True)
default_gen_models = {
"Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct",
"Llama-3.1-8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Mistral-7B Instruct (v0.3)": "mistralai/Mistral-7B-Instruct-v0.3",
}
with gr.Blocks(title="ماحون — مشاور حقوقی (Causal-only, ZeroGPU)") as app:
            # Role banner
role_banner = gr.Markdown()
gr.Markdown("""
<div style='text-align:center;padding:18px'>
<h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only, ZeroGPU)</h1>
<p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning</p>
</div>
""")
# --- Tab: Consultation (interactive for all) ---
with gr.Tab("مشاوره"):
with gr.Row():
gen_model_dd = gr.Dropdown(choices=list(default_gen_models.keys()), value="Qwen2.5-7B Instruct", label="مدل تولید")
gen_model_id = gr.Textbox(value=default_gen_models["Qwen2.5-7B Instruct"], label="Model ID (قابل ویرایش)")
with gr.Row():
use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر ChromaDB")
collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
with gr.Row():
top_k = gr.Slider(1, 15, value=self.scfg.rag.top_k, step=1, label="Top-K")
threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="آستانه شباهت")
load_btn = gr.Button("بارگذاری مدل", variant="primary")
status = gr.Textbox(label="وضعیت", interactive=False)
with gr.Accordion("پارامترهای تولید", open=False):
system_prompt = gr.Textbox(value="You are a helpful Persian legal assistant.", label="System prompt")
max_new_tokens = gr.Slider(64, 2048, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")
question = gr.Textbox(lines=3, label="سوال حقوقی")
gr.Examples(
examples=[
["در صورت نقض قرارداد EPC چه راهکارهای حقوقی دارم؟"],
["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
],
inputs=question, label="نمونه پرسش‌ها"
)
ask_btn = gr.Button("پرسش", variant="primary")
answer = gr.Markdown(label="پاسخ"); refs = gr.Markdown(label="مواد قانونی مرتبط")
# --- Tab: Indexing (view-only for visitors) ---
with gr.Tab("ایندکس قوانین"):
gr.Markdown("فایل JSONL قوانین را بارگذاری و ایندکس کنید (کلیدها: `article_id`, `text`).")
laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
text_key = gr.Textbox(value="text", label="کلید متن ماده")
index_btn = gr.Button("ایندکس‌سازی قوانین"); index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
index_widgets = [laws_file, id_key, text_key, index_btn]
# --- Tab: Dataset Builder (view-only for visitors) ---
with gr.Tab("ساخت دیتاست"):
gr.Markdown("فایل خام (JSON/JSONL) → خروجی JSONL سازگار با `{input, output}` (از golden_builder).")
raw_file = gr.File(label="فایل خام", file_types=[".json",".jsonl"])
with gr.Row():
ds_text_key = gr.Textbox(value="متن_کامل", label="کلید متن (text_key)")
model_ckpt = gr.Dropdown(
choices=["google/mt5-base", "google/flan-t5-base", "t5-base"],
value="google/mt5-base",
label="مدل خلاصه‌ساز برای ساخت دیتاست (فقط Builder)"
)
with gr.Row():
ds_batch_size = gr.Slider(1, 16, value=4, step=1, label="Batch size")
max_samples = gr.Number(value=0, label="حداکثر نمونه (۰=همه)")
build_btn = gr.Button("ساخت دیتاست", variant="primary")
out_file = gr.File(label="دانلود خروجی JSONL", interactive=False)
build_status = gr.Textbox(label="وضعیت", interactive=False)
builder_widgets = [raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples, build_btn]
# --- Tab: Dataset Cleaning (view-only for visitors) ---
with gr.Tab("پاکسازی دیتاست"):
gr.Markdown("نرمال‌سازی فارسی + حذف تکراری‌های معنایی (cosine). ورودی: JSONL `{input, output}`.")
raw_ds = gr.File(label="JSONL ورودی", file_types=[".jsonl"])
sim_th = gr.Slider(0.80, 0.98, value=0.90, step=0.01, label="آستانه شباهت (cosine)")
clean_btn = gr.Button("اجرای پاکسازی", variant="primary")
cleaned_out = gr.File(label="دانلود JSONL پاک", interactive=False)
clean_status = gr.Markdown()
clean_widgets = [raw_ds, sim_th, clean_btn]
# --- Tab: Training (view-only for visitors) ---
with gr.Tab("آموزش"):
gr.Markdown("SFT/LoRA روی مدل‌های causal (فقط `{input, output}`) + W&B logging.")
with gr.Row():
model_train_dd = gr.Dropdown(
choices=[
"HAKIM (Editable ID below)",
"Hooshvareh (Editable ID below)",
"Dorna-Llama3-8B",
"PersianQA-8B",
"Custom (Editable ID below)"
],
value="HAKIM (Editable ID below)", label="پروفایل مدل"
)
model_train_id = gr.Textbox(value="AI-Hoosh/HAKIM-7B", label="HF Model ID (قابل ویرایش)")
use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
use_wandb = gr.Checkbox(value=True, label="W&B logging فعال باشد؟")
wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
gr.Markdown("راهنما: در Settings → Secrets مقدار `WANDB_API_KEY` را تنظیم کنید (مقدار واقعی).")
train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
with gr.Row():
epochs = gr.Slider(1, 6, value=2, step=1, label="epochs")
batch = gr.Slider(1, 8, value=2, step=1, label="batch per device")
lr = gr.Number(value=2e-4, label="learning rate")
train_btn = gr.Button("شروع آموزش", variant="primary")
train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
train_widgets = [model_train_dd, model_train_id, use_rag_train, use_wandb, wandb_project, wandb_entity,
run_name, train_files, epochs, batch, lr, train_btn]
# --- Tab: Weight Tuning (view-only for visitors) ---
with gr.Tab("Weight Tuning"):
gr.Markdown("تیون خودکار وزن‌های موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json",".jsonl"])
tune_text_key = gr.Textbox(value="متن_کامل", label="کلید متن")
tune_max_samples = gr.Slider(50, 400, value=120, step=10, label="حداکثر نمونه")
tune_runs = gr.Slider(4, 64, value=16, step=4, label="تعداد ران Sweep")
tune_batch = gr.Slider(1, 4, value=2, step=1, label="batch size Builder")
tune_proj = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
tune_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
run_tune = gr.Button("شروع Sweep", variant="primary")
tune_status = gr.Markdown()
gr.Markdown("---")
gr.Markdown("اعمال خودکار بهترین وزن‌ها از داشبورد W&B (بر اساس بالاترین `pass_rate`).")
metric_dd = gr.Dropdown(choices=["pass_rate"], value="pass_rate", label="متریک انتخاب بهترین Run")
apply_btn = gr.Button("اعمال بهترین وزن‌ها از W&B", variant="secondary")
tuning_widgets = [tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch,
tune_proj, tune_entity, run_tune, metric_dd, apply_btn]
            # ---- Events (consultation is open / admin operations are gated) ----
def _resolve_gen(choice: str, override: str) -> str:
return override.strip() if override.strip() else default_gen_models[choice]
def _on_load(choice, override, rag, pdir, coll, k, th):
self.scfg.rag.enable = bool(rag)
self.scfg.rag.persist_dir = pdir
self.scfg.rag.collection = coll
self.scfg.rag.top_k = int(k)
self.scfg.rag.similarity_threshold = float(th)
return self.load(_resolve_gen(choice, override))
def _whoami(request: gr.Request):
u = _get_username(request) or "Visitor"
return f"👤 کاربر: **{u}** — دسترسی: {'مدیریتی' if is_admin(request) else 'بازدیدکننده (فقط مشاهده)'}"
load_btn.click(_on_load,
inputs=[gen_model_dd, gen_model_id, use_rag, persist_dir, collection, top_k, threshold],
outputs=status)
ask_btn.click(self.answer,
inputs=[question, system_prompt, use_rag, max_new_tokens, temperature, top_p],
outputs=[answer, refs])
            # Admin ops: rely on request injection (Gradio injects gr.Request automatically)
def _index_handler(f, ik, tk, request: gr.Request):
return self.build_index(f, ik, tk, request)
index_btn.click(_index_handler, inputs=[laws_file, id_key, text_key], outputs=index_status)
def _build_ds_handler(rf, tk, ckpt, bs, mx, request: gr.Request):
return self.build_dataset(rf, tk, ckpt, bs, mx, request)
build_btn.click(_build_ds_handler,
inputs=[raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples],
outputs=[out_file, build_status])
def _train_handler(prof, mid, files, rg, e, b, l, uw, wp, we, rn, request: gr.Request):
def _map_profile_to_id(profile: str, current_id: str) -> str:
if current_id.strip(): return current_id.strip()
if "Dorna" in profile: return "PartAI/Dorna-Llama3-8B-Instruct"
if "PersianQA" in profile: return "zpm/Llama-3.1-PersianQA"
if "HAKIM" in profile: return "AI-Hoosh/HAKIM-7B"
if "Hooshvareh" in profile: return "HooshvareLab/llama-fa-7b-instruct"
return "PartAI/Dorna-Llama3-8B-Instruct"
model_id = _map_profile_to_id(prof, mid)
return self.train(model_id, files, rg, e, b, l, uw, wp, we, rn, request=request)
train_btn.click(_train_handler,
inputs=[model_train_dd, model_train_id, train_files, use_rag_train, epochs, batch, lr,
use_wandb, wandb_project, wandb_entity, run_name],
outputs=train_status)
def _clean_handler(f, th):
p = getattr(f, "name", None) or getattr(f, "path", None)
if not p: return None, "⚠️ فایل نامعتبر."
outp = f"/tmp/cleaned_{int(time.time())}.jsonl"
n = deduplicate_jsonl(p, outp, sim_threshold=float(th))
return outp, f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{n}**"
clean_btn.click(_clean_handler, inputs=[raw_ds, sim_th], outputs=[cleaned_out, clean_status])
def _tune_handler(f, tk, ms, runs, bs, proj, ent, request: gr.Request):
return self.run_weight_tune(f, tk, ms, runs, bs, proj, ent, request)
run_tune.click(_tune_handler,
inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
outputs=tune_status)
def _apply_best_handler(proj, ent, m, request: gr.Request):
return self.apply_best_weights(proj, ent, m, request)
apply_btn.click(_apply_best_handler,
inputs=[tune_proj, tune_entity, metric_dd],
outputs=tune_status)
# --- Lock non-consultation tabs for visitors on load ---
def _gate_all(request: gr.Request):
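                # One gr.update per gated widget; order must match the outputs list passed to app.load below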
admin = is_admin(request)
role_txt = f"👤 کاربر: **{_get_username(request) or 'Visitor'}** — دسترسی: {'مدیریتی' if admin else 'بازدیدکننده (فقط مشاهده)'}"
if not admin:
lock = gr.update(interactive=False)
updates = [lock] * (len(index_widgets) + len(builder_widgets) + len(clean_widgets) + len(train_widgets) + len(tuning_widgets))
else:
unlock = gr.update(interactive=True)
updates = [unlock] * (len(index_widgets) + len(builder_widgets) + len(clean_widgets) + len(train_widgets) + len(tuning_widgets))
return [role_txt] + updates
app.load(_whoami, inputs=None, outputs=role_banner)
app.load(_gate_all, inputs=None,
outputs=[role_banner] + index_widgets + builder_widgets + clean_widgets + train_widgets + tuning_widgets)
return app
# ==========================
# Entrypoint
# ==========================
if __name__ == "__main__":
app = LegalApp()
ui = app.build_ui()
try:
        ui = ui.queue()  # stable under ZeroGPU
except TypeError:
pass
ui.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)