# -*- coding: utf-8 -*- """ Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B + ZeroGPU + Role Gating - تب «مشاوره» برای همه تعاملی است. - تب‌های «ایندکس»، «ساخت دیتاست»، «پاکسازی»، «آموزش»، «Weight Tuning» برای بازدیدکننده فقط نمایشی‌اند؛ و سمت‌سرور نیز گِیت نقش دارد (ادمین/بازدیدکننده). پیش‌نیازها: - golden_builder.py , weights_sweep.py - Settings → Secrets: WANDB_API_KEY (در صورت استفاده از W&B) - Settings → Environment Variables: ADMIN_USERS (مثلاً: haji-mammad, teammate1) - requirements.txt (ZeroGPU-ready) شامل spaces>=0.42.0 """ from __future__ import annotations # --- Telemetry hard-off + ZeroGPU SDK (must be before chroma import) --- import os, logging os.environ["CHROMA_TELEMETRY_ENABLED"] = "false" os.environ["ANONYMIZED_TELEMETRY"] = "false" import spaces # ZeroGPU SDK # (اختیاری) کاهش نویز لاگ‌ها logging.getLogger("chromadb").setLevel(logging.ERROR) logging.getLogger("posthog").setLevel(logging.CRITICAL) # ----------------------------------------------------------------------- import sys, re, json, time, pickle, zipfile, warnings from dataclasses import dataclass, field from pathlib import Path from typing import List, Dict, Optional import numpy as np import torch from torch.utils.data import Dataset from sklearn.model_selection import train_test_split import gradio as gr warnings.filterwarnings("ignore") # ====== Transformers ====== import transformers as tf from transformers import ( AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, EarlyStoppingCallback ) # ====== RAG stack ====== import chromadb from chromadb.config import Settings from rank_bm25 import BM25Okapi from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util # ---- Monkeypatch Chroma telemetry (fallback) ---- try: import chromadb.telemetry as _ctel try: _ctel.client = None except Exception: pass for _n in ("capture", "capture_event"): if hasattr(_ctel, _n): try: setattr(_ctel, _n, lambda *a, **k: None) except Exception: pass if 
# ========= Persian normalization =========
ZWNJ = "\u200c"
AR_DIGITS = "٠١٢٣٤٥٦٧٨٩"
FA_DIGITS = "۰۱۲۳۴۵۶۷۸۹"
EN_DIGITS = "0123456789"

# Built once at import time instead of on every normalize_fa() call.
_DIGIT_TRANS = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)}
_DIACRITICS_RE = re.compile(r"[\u064B-\u065F\u0610-\u061A]")
# Same pattern as before, but the (invisible) ZWNJ is spelled via the constant.
_ZWNJ_SPACING_RE = re.compile(rf"\s*{ZWNJ}\s*")
_WHITESPACE_RE = re.compile(r"\s+")


def normalize_fa(s: str) -> str:
    """Normalize Persian text for indexing and matching.

    Steps, in order: map Arabic yeh/kaf to their Persian forms, drop the
    combining marks in U+064B–U+065F and U+0610–U+061A, convert Arabic-Indic
    and Persian digits to ASCII digits, remove whitespace around ZWNJ, collapse
    whitespace runs to a single space and strip the ends.

    Falsy input (``""`` or ``None``) is returned unchanged.
    """
    if not s:
        return s
    s = s.replace("\u064A", "ی").replace("\u0643", "ک")
    s = _DIACRITICS_RE.sub("", s)
    s = s.translate(_DIGIT_TRANS)
    s = _ZWNJ_SPACING_RE.sub(ZWNJ, s)
    return _WHITESPACE_RE.sub(" ", s).strip()


# ==========================
# Configs
# ==========================
@dataclass
class ModelConfig:
    """Generation-time settings for the causal language model."""
    model_name: str = "Qwen/Qwen2.5-7B-Instruct"
    max_input_length: int = 3072       # tokenizer truncation length for the prompt
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    do_sample: bool = True
    gradient_checkpointing: bool = True


@dataclass
class RAGConfig:
    """Settings for the Chroma + BM25 + cross-encoder retrieval stack."""
    persist_dir: str = "./chroma_db"
    collection: str = "legal_articles"
    top_k: int = 6
    similarity_threshold: float = 0.68  # hits below this similarity are dropped
    context_char_limit: int = 260       # characters of each article put in the prompt
    enable: bool = True
    reranker_name: str = "Alibaba-NLP/gte-multilingual-reranker-base"


@dataclass
class TrainConfig:
    """Fine-tuning hyper-parameters and model identifiers."""
    base_model: str = "PartAI/Dorna-Llama3-8B-Instruct"
    alt_model_1: str = "zpm/Llama-3.1-PersianQA"
    hakim_model: str = "AI-Hoosh/HAKIM-7B"
    hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct"
    output_dir: str = "./mahoon_causal_lora"
    seed: int = 42
    test_size: float = 0.1
    epochs: int = 2
    batch_size: int = 2
    grad_accum: int = 4
    lr: float = 2e-4
    warmup_ratio: float = 0.03
    weight_decay: float = 0.0
    logging_steps: int = 50
    eval_strategy: str = "epoch"
    save_strategy: str = "epoch"
    save_total_limit: int = 2
    report_to: str = "wandb"
    max_grad_norm: float = 1.0
    use_4bit: bool = False
    max_seq_len: int = 2048


@dataclass
class SystemConfig:
    """Bundle of the model, RAG and training configurations."""
    model: ModelConfig = field(default_factory=ModelConfig)
    rag: RAGConfig = field(default_factory=RAGConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
# ==========================
# Helpers
# ==========================
def set_seed_all(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (plus CUDA when available)."""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def bf16_supported():
    """Return a truthy value only if CUDA is available and reports bf16 support."""
    return torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)()


def log_deps():
    """Print the versions of the main dependencies; never raises."""
    try:
        import accelerate, datasets
        print(
            "[deps]",
            f"python={sys.version.split()[0]}",
            f"transformers={tf.__version__}",
            f"accelerate={accelerate.__version__}",
            f"datasets={datasets.__version__}",
            f"gradio={gr.__version__}",
            flush=True,
        )
    except Exception as e:
        print("[deps] warn:", e, flush=True)


# ==========================
# Role gating helpers
# ==========================
def _get_username(request: gr.Request) -> str | None:
    """Return ``request.username`` or ``None`` when it is missing/unreadable."""
    try:
        return getattr(request, "username", None)
    except Exception:
        return None


def is_admin(request: gr.Request) -> bool:
    """Tell whether the request's user is an admin.

    A user is an admin when the username equals the ``SPACE_AUTHOR_NAME``
    environment variable or appears in the comma-separated ``ADMIN_USERS``
    environment variable. Requests without a username are never admin.
    """
    uname = _get_username(request)
    if not uname:
        return False
    author = os.getenv("SPACE_AUTHOR_NAME", "").strip()
    allow = {u.strip() for u in os.getenv("ADMIN_USERS", "").split(",") if u.strip()}
    return (uname == author) or (uname in allow)


# ==========================
# RAG: Chroma + BM25 + CrossEncoder reranker
# ==========================
class LegalRAG:
    """Hybrid retriever: Chroma dense search + BM25, optionally cross-encoder reranked."""

    # Documents sent to Chroma per upsert call; Chroma rejects oversized batches.
    UPSERT_BATCH_SIZE = 1000

    def __init__(self, cfg: RAGConfig):
        self.cfg = cfg
        self.client = None
        self.collection = None
        self.reranker: Optional[CrossEncoder] = None
        self.bm25 = None
        self.bm25_ids: List[str] = []
        self.bm25_path = str(Path(self.cfg.persist_dir) / "bm25.pkl")

    def init(self):
        """Open (or create) the Chroma collection, load the reranker and the BM25 pickle.

        A reranker that fails to load leaves ``self.reranker`` as ``None``.
        An unreadable BM25 pickle is logged and leaves BM25 disabled.
        """
        Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(
            path=self.cfg.persist_dir,
            settings=Settings(anonymized_telemetry=False),
        )
        try:
            self.collection = self.client.get_or_create_collection(self.cfg.collection)
        except Exception:
            try:
                self.collection = self.client.get_collection(self.cfg.collection)
            except Exception:
                self.collection = self.client.create_collection(self.cfg.collection)
        try:
            self.reranker = CrossEncoder(self.cfg.reranker_name, device="cpu")
        except Exception:
            self.reranker = None
        if Path(self.bm25_path).exists():
            # SECURITY: pickle executes arbitrary code on load; bm25.pkl must only
            # ever be a file written by _rebuild_bm25(), never an uploaded one.
            try:
                with open(self.bm25_path, "rb") as f:
                    obj = pickle.load(f)
                self.bm25, self.bm25_ids = obj["bm25"], obj["ids"]
            except Exception as exc:
                # A stale/corrupt pickle should not take the dense index down too.
                logging.getLogger(__name__).warning("BM25 index could not be loaded: %s", exc)
                self.bm25, self.bm25_ids = None, []

    def _rebuild_bm25(self, ids: List[str], docs: List[str]):
        """Build a BM25 index over whitespace-tokenized ``docs`` and pickle it."""
        corpus = [normalize_fa(d).split() for d in docs]
        self.bm25 = BM25Okapi(corpus)
        self.bm25_ids = ids
        with open(self.bm25_path, "wb") as f:
            pickle.dump({"bm25": self.bm25, "ids": self.bm25_ids}, f)

    def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
        """Index a JSONL file of articles into Chroma and rebuild BM25.

        Lines that are blank, not valid JSON, not JSON objects, or whose text is
        empty are skipped. Ids are normalized (yeh/kaf, digits, whitespace
        removed); repeated ids get a ``__dN`` suffix so Chroma ids stay unique,
        while the metadata keeps the un-suffixed ``article_id``.

        Returns a status message (Persian) for the UI.
        """
        if self.collection is None:
            self.init()

        digit_map = {ord(a): e for a, e in zip("٠١٢٣٤٥٦٧٨٩۰۱۲۳۴۵۶۷۸۹", "01234567890123456789")}

        def _norm_id(x: str) -> str:
            x = (x or "").replace("\u064A", "ی").replace("\u0643", "ک")
            return re.sub(r"\s+", "", x.translate(digit_map))

        seen: Dict[str, int] = {}
        ids, docs, metas = [], [], []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for i, line in enumerate(f):
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except json.JSONDecodeError:
                    continue
                if not isinstance(obj, dict):
                    continue
                base_id = _norm_id(str(obj.get(id_key, f"auto_{i}")))
                txt = normalize_fa(str(obj.get(text_key, "")).strip())
                if not txt:
                    continue
                seen[base_id] = seen.get(base_id, 0) + 1
                dupe_idx = seen[base_id]
                uid = base_id if dupe_idx == 1 else f"{base_id}__d{dupe_idx}"
                ids.append(uid)
                docs.append(txt)
                metas.append({"article_id": base_id, "dupe_idx": dupe_idx})

        if not ids:
            return "هیچ سندی برای ایندکس یافت نشد."

        step = self.UPSERT_BATCH_SIZE
        for start in range(0, len(ids), step):
            stop = start + step
            self.collection.upsert(
                ids=ids[start:stop], documents=docs[start:stop], metadatas=metas[start:stop]
            )
        self._rebuild_bm25(ids, docs)
        dup_count = sum(1 for c in seen.values() if c > 1)
        return f"✅ {len(ids)} سند ایندکس شد (Dense+BM25). شناسه‌های تکراری: {dup_count} کلید (با پسوند __dN یکتا شدند)."

    def _fetch_by_uid(self, uids: List[str]) -> Dict[str, tuple]:
        """Fetch ``uid -> (document, metadata)`` from Chroma in one call; ``{}`` on failure."""
        if not uids:
            return {}
        try:
            got = self.collection.get(ids=uids)
        except Exception:
            return {}
        return {
            uid: (doc, meta)
            for uid, doc, meta in zip(
                got.get("ids") or [], got.get("documents") or [], got.get("metadatas") or []
            )
        }

    def retrieve(self, query: str) -> List[Dict]:
        """Return up to ``top_k`` articles for ``query``.

        Dense and BM25 candidates are merged per ``article_id`` (keeping the
        higher similarity), filtered by text length (> 15 chars) and by
        ``similarity_threshold``, then ordered by reranker score when a reranker
        is loaded, otherwise by similarity. Each hit is a dict with
        ``article_id``, ``text``, ``similarity`` and, when reranked, ``score``.
        """
        if self.collection is None:
            return []
        qn = normalize_fa(query)
        n_candidates = max(self.cfg.top_k * 3, 20)

        # Dense
        dense_hits: List[Dict] = []
        try:
            res = self.collection.query(
                query_texts=[qn],
                n_results=n_candidates,
                include=["documents", "metadatas", "distances"],
            )
            docs = res.get("documents", [[]])[0]
            metas = res.get("metadatas", [[]])[0]
            dists = res.get("distances", [[1.0]])[0]
            for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)):
                dense_hits.append({
                    "article_id": (meta or {}).get("article_id", f"unk_{i}"),
                    "text": doc,
                    "similarity": 1.0 - float(dist),
                })
        except Exception:
            dense_hits = []

        # BM25
        bm25_hits: List[Dict] = []
        if self.bm25 is not None and self.bm25_ids:
            scores = self.bm25.get_scores(qn.split())
            idxs = np.argsort(scores)[::-1][:n_candidates]
            smax = float(scores.max() + 1e-8)  # epsilon avoids division by zero
            uids = [self.bm25_ids[int(j)] for j in idxs]
            stored = self._fetch_by_uid(uids)
            for j, uid in zip(idxs, uids):
                doc, meta = stored.get(uid, ("", None))
                bm25_hits.append({
                    # Use the un-suffixed metadata id so BM25 and dense hits for
                    # the same article land on the same merge key.
                    "article_id": (meta or {}).get("article_id", uid),
                    "text": doc or "",
                    "similarity": float(scores[j]) / smax,
                })

        # merge
        pool: Dict[str, Dict] = {}
        for hit in dense_hits + bm25_hits:
            best = pool.get(hit["article_id"])
            if best is None or hit.get("similarity", 0) > best.get("similarity", 0):
                pool[hit["article_id"]] = hit
        merged = [
            a for a in pool.values()
            if a.get("text") and len(a["text"]) > 15
            and a.get("similarity", 0) >= self.cfg.similarity_threshold
        ]

        # rerank (GPU only during predict)
        if merged and self.reranker:
            pairs = [(qn, a["text"]) for a in merged]
            try:
                # NOTE(review): spaces.GPU is documented as a decorator; if it is
                # not a context manager this `with` raises and the except branch
                # below always does the work — confirm.
                with spaces.GPU(duration=30):
                    scores = self.reranker.predict(pairs)
            except Exception:
                scores = self.reranker.predict(pairs)
            for a, s in zip(merged, scores):
                a["score"] = float(s)
            sort_key = "score"
        else:
            sort_key = "similarity"
        merged.sort(key=lambda x: x.get(sort_key, 0), reverse=True)
        return merged[: self.cfg.top_k]

    def build_context(self, arts: List[Dict]) -> str:
        """Format hits as a bullet list, truncating each text to ``context_char_limit``."""
        if not arts:
            return ""
        bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." for a in arts]
        return "مواد مرتبط:\n" + "\n".join(bullets)
# ========= RAG bootstrap from repo =========
def parse_law_textfile_to_jsonl(txt_path: str, out_jsonl: str):
    """Extract ``ماده <n>: <body>`` lines from a text file into a JSONL file.

    Only lines that start with the article marker are used; bodies shorter than
    12 characters are skipped. Each output row is
    ``{"article_id": <n>, "text": <normalized body>}``.

    Returns the number of rows written.
    Raises ``RuntimeError`` when no line matches.
    """
    pat = re.compile(r"(?:ماده|مادّه)\s+(\d+)\s*[:\-–]\s*(.+)")
    rows = []
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            m = pat.match(line.strip())
            if not m:
                continue
            body = m.group(2).strip()
            if len(body) < 12:
                continue
            rows.append({"article_id": m.group(1), "text": normalize_fa(body)})
    if not rows:
        raise RuntimeError("هیچ ماده‌ای با الگوی تعریف‌شده پیدا نشد.")
    with open(out_jsonl, "w", encoding="utf-8") as g:
        g.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in rows)
    return len(rows)


def ensure_chroma_ready(persist_dir="./chroma_db", collection="legal_articles") -> str:
    """Make sure a Chroma directory exists, bootstrapping it when it is empty.

    Order of attempts: (1) keep ``persist_dir`` if it already contains
    anything; (2) extract ``./chroma_legal_db.zip`` into it; (3) parse
    ``./all_legal_sentences.txt`` and index it. A failed zip extraction is
    logged and the next source is tried.

    Returns a status message (Persian) describing what happened.
    """
    Path(persist_dir).mkdir(parents=True, exist_ok=True)
    if any(Path(persist_dir).glob("*")):
        return "ChromaDB موجود است."

    zip_path = Path("./chroma_legal_db.zip")
    if zip_path.exists():
        try:
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(persist_dir)
            return "ChromaDB از zip بازیابی شد."
        except (zipfile.BadZipFile, OSError) as exc:
            # Previously swallowed silently; keep falling through, but say why.
            logging.getLogger(__name__).warning("Could not restore ChromaDB from %s: %s", zip_path, exc)

    txt_path = Path("./all_legal_sentences.txt")
    if txt_path.exists():
        n = parse_law_textfile_to_jsonl(str(txt_path), "./laws.jsonl")
        rag_local = LegalRAG(RAGConfig(persist_dir=persist_dir, collection=collection))
        rag_local.init()
        msg = rag_local.index_jsonl("./laws.jsonl", id_key="article_id", text_key="text")
        return f"از متن خام {n} رکورد استخراج شد. {msg}"

    return "پایگاه RAG موجود نیست و منبع خامی هم برای ساخت پیدا نشد."
# ==========================
# Loader + Generator (Causal-only, ZeroGPU)
# ==========================
class CausalLoader:
    """Loads a tokenizer and a causal LM by Hugging Face model id."""

    def __init__(self, mcfg: ModelConfig):
        self.cfg = mcfg
        self.tokenizer = None
        self.model = None

    def load(self, model_name: str):
        """Load tokenizer and model for ``model_name`` and return ``self``.

        The pad token falls back to the EOS token when the tokenizer has none.
        With CUDA available the model is loaded with ``device_map="auto"`` in
        bf16 (or fp16). If that attempt raises for any reason, the model is
        loaded again with default placement/dtype and the reason is logged.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"):
            self.tokenizer.pad_token = self.tokenizer.eos_token
        try:
            # NOTE(review): spaces.GPU is documented as a decorator; if it is not
            # a context manager this `with` raises and the fallback below is the
            # path that actually runs — confirm.
            with spaces.GPU(duration=90):
                kwargs = {"low_cpu_mem_usage": True}
                if torch.cuda.is_available():
                    kwargs["device_map"] = "auto"
                    kwargs["torch_dtype"] = torch.bfloat16 if bf16_supported() else torch.float16
                self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
                if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
                    try:
                        self.model.gradient_checkpointing_enable()
                    except Exception:
                        pass
        except Exception as exc:
            logging.getLogger(__name__).warning(
                "Accelerated load of %s failed (%s); falling back to default load.", model_name, exc
            )
            self.model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
        return self


class Generator:
    """Builds the chat-style prompt and runs ``model.generate``."""

    def __init__(self, loader: CausalLoader, mcfg: ModelConfig):
        self.tk = loader.tokenizer
        self.model = loader.model
        self.cfg = mcfg

    def _generation_kwargs(self, max_new_tokens: int) -> dict:
        """Return keyword arguments for ``model.generate``.

        Sampling parameters are only passed when sampling is on and the
        temperature is positive; otherwise decoding is greedy, because a
        temperature of 0 is rejected by ``generate`` when ``do_sample=True``.
        """
        pad_id = self.tk.pad_token_id
        if pad_id is None:  # 0 is a valid pad id, so test for None explicitly
            pad_id = self.tk.eos_token_id
        sample = bool(self.cfg.do_sample) and float(self.cfg.temperature) > 0.0
        kwargs = {"max_new_tokens": max_new_tokens, "do_sample": sample, "pad_token_id": pad_id}
        if sample:
            kwargs["temperature"] = self.cfg.temperature
            kwargs["top_p"] = self.cfg.top_p
        return kwargs

    def generate(self, question: str, context: str = "",
                 system_prompt: str = "You are a helpful Persian legal assistant.") -> str:
        """Generate an answer for ``question``.

        The prompt is ``<|system|>`` (system prompt), an optional second
        ``<|system|>`` section carrying ``context``, ``<|user|>`` and an open
        ``<|assistant|>`` turn. Only the newly generated tokens are decoded, so
        the returned text does not repeat the prompt.
        """
        parts = []
        if system_prompt:
            parts.append(f"<|system|>\n{system_prompt}")
        if context:
            parts.append(f"<|system|>\nاز منابع زیر استفاده کن و استنادی پاسخ بده:\n{context}")
        parts.append(f"<|user|>\n{question}")
        prompt = "\n".join(parts) + "\n<|assistant|>\n"

        enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
        try:
            # NOTE(review): see CausalLoader.load about spaces.GPU used via `with`.
            with spaces.GPU(duration=60):
                dev_model = next(self.model.parameters()).device if hasattr(self.model, "parameters") else "cpu"
                inputs = {k: v.to(dev_model) for k, v in enc.items()}
                with torch.no_grad():
                    out = self.model.generate(**inputs, **self._generation_kwargs(self.cfg.max_new_tokens))
        except Exception:
            # Fallback path: inputs stay where the tokenizer put them and the
            # output length is capped at 256 new tokens.
            inputs = {k: v for k, v in enc.items()}
            with torch.no_grad():
                out = self.model.generate(
                    **inputs, **self._generation_kwargs(min(self.cfg.max_new_tokens, 256))
                )

        # generate() returns prompt + continuation for causal models; drop the prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        return self.tk.decode(out[0][prompt_len:], skip_special_tokens=True)
# ==========================
# Datasets & Trainer (Causal-only, W&B)
# ==========================
def read_jsonl_files(paths: List[str]) -> List[Dict]:
    """Read every JSONL file in ``paths`` into one list.

    Falsy paths, blank lines and lines that are not valid JSON are skipped.
    """
    data: List[Dict] = []
    for p in paths:
        if not p:
            continue
        with open(p, 'r', encoding='utf-8') as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    data.append(json.loads(s))
                except json.JSONDecodeError:
                    continue
    return data


class CausalJSONLDataset(Dataset):
    """Turns ``{"input", "output"}`` records into tokenized causal-LM examples.

    Each example is rendered as system / user / assistant text. When ``rag`` is
    given, every ``enhance_every``-th record (by position in ``data``) also gets
    a retrieved-context system section. Records with an empty input or output
    are dropped.
    """

    def __init__(self, data: List[Dict], tokenizer, max_len: int,
                 rag: Optional[LegalRAG] = None, enhance_every: int = 8):
        self.tk = tokenizer
        self.max_len = max_len
        self.items = []
        for i, ex in enumerate(data):
            src = normalize_fa(str(ex.get("input", "")).strip())
            tgt = normalize_fa(str(ex.get("output", "")).strip())
            if not src or not tgt:
                continue
            ctx = ""
            if rag and i % enhance_every == 0:
                ctx = rag.build_context(rag.retrieve(src))
            text = ""
            if ctx:
                text += f"<|system|>\nاز منابع زیر استفاده کن:\n{ctx}\n"
            text += "<|system|>\nYou are a helpful Persian legal assistant.\n"
            text += f"<|user|>\n{src}\n<|assistant|>\n{tgt}"
            self.items.append(text)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Tokenize item ``idx`` padded/truncated to ``max_len``; padding is masked in labels."""
        enc = self.tk(self.items[idx], max_length=self.max_len, padding="max_length", truncation=True)
        input_ids = torch.tensor(enc["input_ids"])
        attn = torch.tensor(enc["attention_mask"])
        labels = input_ids.clone()
        labels[attn == 0] = -100  # -100 is ignored by the loss
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}


def safe_training_args(**kwargs):
    """Build ``TrainingArguments`` across the evaluation-strategy keyword rename.

    Transformers renamed ``evaluation_strategy`` to ``eval_strategy``. If the
    installed version rejects the spelling that was passed (``TypeError``), the
    call is retried once with the other spelling. Any other ``TypeError`` is
    re-raised unchanged.
    """
    try:
        return TrainingArguments(**kwargs)
    except TypeError:
        swap = {"evaluation_strategy": "eval_strategy", "eval_strategy": "evaluation_strategy"}
        renamed = {swap.get(k, k): v for k, v in kwargs.items()}
        if renamed.keys() == kwargs.keys():
            raise
        return TrainingArguments(**renamed)


class TrainerManager:
    """Runs supervised fine-tuning of the loaded causal model with HF ``Trainer``."""

    def __init__(self, syscfg: SystemConfig, loader: CausalLoader):
        self.cfg = syscfg
        self.loader = loader

    def train_causal(self, train_paths: List[str], use_rag: bool = True, use_wandb: bool = True,
                     wandb_project: str = "mahoon-legal-ai", wandb_entity: str = "",
                     run_name: str = "mahoon_causal_lora"):
        """Train on the JSONL files in ``train_paths`` and save model + tokenizer.

        The data is split into train/validation with ``cfg.train.test_size``.
        Only the training split is RAG-enhanced. With ``use_wandb`` the run is
        reported to W&B and the output directory is logged as an artifact;
        W&B failures are ignored so they never abort training.

        Raises ``ValueError`` when fewer than two records are available.
        """
        set_seed_all(self.cfg.train.seed)
        data = read_jsonl_files(train_paths)
        if len(data) < 2:
            # train_test_split would fail with a less helpful message.
            raise ValueError("برای آموزش حداقل دو رکورد معتبر لازم است.")
        train, val = train_test_split(
            data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed
        )

        rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
        if rag:
            rag.init()
        ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.train.max_seq_len, rag)
        ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.train.max_seq_len, None)

        bf16_ok = bf16_supported()
        fp16_ok = torch.cuda.is_available() and not bf16_ok

        if use_wandb:
            # setdefault: values already present in the environment win.
            os.environ.setdefault("WANDB_PROJECT", wandb_project or "mahoon-legal-ai")
            if wandb_entity:
                os.environ.setdefault("WANDB_ENTITY", wandb_entity)
            os.environ.pop("WANDB_DISABLED", None)
        else:
            os.environ["WANDB_DISABLED"] = "true"

        args = safe_training_args(
            output_dir=self.cfg.train.output_dir,
            num_train_epochs=self.cfg.train.epochs,
            learning_rate=self.cfg.train.lr,
            per_device_train_batch_size=self.cfg.train.batch_size,
            per_device_eval_batch_size=self.cfg.train.batch_size,
            gradient_accumulation_steps=self.cfg.train.grad_accum,
            warmup_ratio=self.cfg.train.warmup_ratio,
            weight_decay=self.cfg.train.weight_decay,
            evaluation_strategy=self.cfg.train.eval_strategy,
            save_strategy=self.cfg.train.save_strategy,
            save_total_limit=self.cfg.train.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=self.cfg.train.logging_steps,
            report_to=(["wandb"] if use_wandb else ["none"]),
            run_name=run_name,
            fp16=fp16_ok,
            bf16=bf16_ok,
            max_grad_norm=self.cfg.train.max_grad_norm,
        )

        callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
        try:
            if use_wandb:
                from transformers.integrations import WandbCallback
                callbacks.append(WandbCallback())
        except Exception:
            pass

        trainer = Trainer(
            model=self.loader.model,
            args=args,
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            tokenizer=self.loader.tokenizer,
            callbacks=callbacks,
        )

        if use_wandb:
            try:
                import wandb
                wandb.init(
                    project=os.getenv("WANDB_PROJECT", "mahoon-legal-ai"),
                    entity=os.getenv("WANDB_ENTITY"),
                    name=run_name,
                    config={
                        "base_model": self.loader.model.name_or_path,
                        "epochs": self.cfg.train.epochs,
                        "batch": self.cfg.train.batch_size,
                        "grad_accum": self.cfg.train.grad_accum,
                        "lr": self.cfg.train.lr,
                        "max_seq_len": self.cfg.train.max_seq_len,
                        "use_rag": use_rag,
                    },
                )
            except Exception:
                pass

        trainer.train()
        trainer.save_model(self.cfg.train.output_dir)
        self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)

        if use_wandb:
            try:
                import wandb
                art = wandb.Artifact("mahoon-model", type="model")
                art.add_dir(self.cfg.train.output_dir)
                wandb.log_artifact(art)
                wandb.finish()
            except Exception:
                pass
# ==========================
# Dataset utilities (Cleaner/Deduper)
# ==========================
def deduplicate_jsonl(in_path: str, out_path: str, sim_threshold: float = 0.90,
                      text_keys=("input", "output")) -> int:
    """Normalize a JSONL dataset and drop semantically duplicated inputs.

    Every JSON-object line has its ``text_keys`` values normalized with
    ``normalize_fa``. The ``input`` fields are embedded with a multilingual
    MiniLM sentence encoder; rows are scanned in order and a row is kept unless
    an earlier kept row has cosine similarity >= ``sim_threshold`` with it.
    Blank lines, invalid JSON and non-object JSON values are skipped.

    Returns the number of rows written to ``out_path``.
    Raises ``RuntimeError`` when the input holds no usable row.
    """
    rows = []
    with open(in_path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                obj = json.loads(s)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                # Lists/strings/numbers have no "input" field to embed.
                continue
            for k in text_keys:
                if k in obj:
                    obj[k] = normalize_fa(str(obj[k]))
            rows.append(obj)
    if not rows:
        raise RuntimeError("هیچ رکورد معتبری در ورودی نبود.")

    model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    embs = model.encode(
        [r.get("input", "") for r in rows],
        convert_to_tensor=True,
        show_progress_bar=False,
        normalize_embeddings=True,
    )

    kept = []
    seen = torch.zeros(len(rows), dtype=torch.bool)
    for i, row in enumerate(rows):
        if seen[i]:
            continue
        # One row against all rows; the whole pass is quadratic in dataset size.
        sims = st_util.cos_sim(embs[i], embs)[0]
        seen[sims >= sim_threshold] = True
        kept.append(row)

    with open(out_path, "w", encoding="utf-8") as g:
        g.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in kept)
    return len(kept)
# ==========================
# App (Gradio) + Role Gating
# ==========================
class LegalApp:
    """Gradio application: public consultation tab plus admin-gated data/training tabs."""

    def __init__(self, scfg: Optional[SystemConfig] = None):
        self.scfg = scfg or SystemConfig()
        self.rag = LegalRAG(self.scfg.rag)
        self.loader: Optional[CausalLoader] = None
        self.gen: Optional[Generator] = None

    @staticmethod
    def _path_of(file_obj) -> Optional[str]:
        """Return the filesystem path of a Gradio file value, or ``None``.

        Accepts plain string paths as well as objects exposing ``name``/``path``.
        """
        if not file_obj:
            return None
        if isinstance(file_obj, str):
            return file_obj
        return getattr(file_obj, "name", None) or getattr(file_obj, "path", None)

    def _file_paths(self, files: List[gr.File]) -> List[str]:
        """Return the usable paths among ``files`` (``None`` is treated as empty)."""
        return [p for p in (self._path_of(f) for f in (files or [])) if p]

    # Core (consultation / model loading is open to everyone)
    def load(self, model_name: str):
        """Load the generation model and (re)initialize RAG; returns a status text."""
        self.loader = CausalLoader(self.scfg.model).load(model_name)
        self.gen = Generator(self.loader, self.scfg.model)
        # RAG
        msg_rag = "RAG غیرفعال"
        if self.scfg.rag.enable:
            try:
                self.rag = LegalRAG(self.scfg.rag)
                self.rag.init()
                msg_rag = "RAG آماده است"
            except Exception as e:
                msg_rag = f"RAG خطا: {e}"
        return f"مدل بارگذاری شد: {model_name}\n{msg_rag}"

    # --- Server-side gate: admins only ---
    def build_index(self, laws_file: gr.File, id_key: str, text_key: str, request: gr.Request):
        """Index an uploaded laws JSONL file (admin only); returns a status text."""
        if not is_admin(request):
            return "🔒 این عملیات فقط برای مدیران فعال است."
        if not self.scfg.rag.enable:
            return "RAG غیرفعال است."
        try:
            self.rag.init()
            p = self._path_of(laws_file)
            if not p:
                return "فایل قوانین معتبر نیست."
            return self.rag.index_jsonl(p, id_key=id_key, text_key=text_key)
        except Exception as e:
            return f"خطا در ایندکس: {e}"

    def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int,
                      max_samples: int | None, request: gr.Request):
        """Build an ``{input, output}`` JSONL via ``golden_builder`` (admin only).

        Returns ``(output_path_or_None, status_text)``.
        """
        if not is_admin(request):
            return None, "🔒 این عملیات فقط برای مدیران فعال است."
        try:
            from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
        except Exception as e:
            return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
        path = self._path_of(raw_file)
        if not path:
            return None, "⚠️ فایل ورودی معتبر نیست."
        try:
            data = load_json_or_jsonl(path)
            if max_samples and int(max_samples) > 0:
                data = data[:int(max_samples)]
            gb = GoldenBuilder(model_name=model_ckpt)
            rows = gb.build(data, text_key=text_key, batch_size=int(batch_size))
            out_dir = "/tmp/mahoon_datasets"
            Path(out_dir).mkdir(parents=True, exist_ok=True)
            out_path = f"{out_dir}/golden_{os.path.basename(path)}.jsonl"
            save_jsonl(rows, out_path)
            return out_path, f"✅ {len(rows)} رکورد تولید شد."
        except Exception as e:
            return None, f"❌ خطا در ساخت دیتاست: {e}"

    def train(self, model_name: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int,
              lr: float, use_wandb: bool, wandb_project: str, wandb_entity: str, run_name: str,
              progress=gr.Progress(track_tqdm=True), request: gr.Request = None):
        """Fine-tune ``model_name`` on the uploaded JSONL files (admin only).

        The uploaded files are validated before the model is loaded. Training
        errors propagate to the caller. Returns a status text.
        """
        if not is_admin(request):
            return "🔒 این عملیات فقط برای مدیران فعال است."
        progress(0.05, desc="راه‌اندازی")
        paths = self._file_paths(files)
        if not paths:
            # Checked first so a missing upload does not cost a model load.
            return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده."
        self.scfg.train.epochs = int(epochs)
        self.scfg.train.batch_size = int(batch)
        self.scfg.train.lr = float(lr)
        progress(0.10, desc="بارگذاری مدل/توکنایزر")
        self.loader = CausalLoader(self.scfg.model).load(model_name)
        tm = TrainerManager(self.scfg, self.loader)
        set_seed_all(self.scfg.train.seed)
        progress(0.30, desc="آماده‌سازی دیتاست‌ها و RAG (اختیاری)")
        tm.train_causal(
            paths,
            use_rag=use_rag,
            use_wandb=use_wandb,
            wandb_project=wandb_project,
            wandb_entity=wandb_entity,
            run_name=run_name,
        )
        progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
        return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."

    def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent, request: gr.Request):
        """Run a W&B sweep through ``weights_sweep.run_sweep`` (admin only)."""
        if not is_admin(request):
            return "🔒 این عملیات فقط برای مدیران فعال است."
        p = self._path_of(f)
        if not p:
            return "⚠️ فایل داده نامعتبر است."
        try:
            from weights_sweep import run_sweep
        except Exception as e:
            return f"❌ weights_sweep.py یافت نشد/قابل import نیست: {e}"
        os.environ.setdefault("WANDB_PROJECT", proj or "mahoon-legal-ai")
        if ent:
            os.environ.setdefault("WANDB_ENTITY", ent)
        try:
            run_sweep(data_path=p, text_key=tk, max_samples=int(ms), batch_size=int(bs),
                      project=proj, entity=ent, count=int(runs))
            return "✅ Sweep اجرا شد. بهترین Run را در W&B بررسی و وزن‌ها را تثبیت کنید."
        except Exception as e:
            return f"❌ خطا در اجرای Sweep: {e}"

    def apply_best_weights(self, wandb_project: str, wandb_entity: str, metric: str = "pass_rate",
                           request: gr.Request = None):
        """Write the ``weights`` of the best finished W&B run to ``legal_entity_weights.json``.

        "Best" is the highest ``metric`` among finished runs whose summary has
        both ``weights`` and ``metric``. Returns a status text.
        """
        # NOTE(review): unlike the other admin methods, a missing request skips
        # the gate here (direct Python callers are let through) — confirm intent.
        if request is not None and not is_admin(request):
            return "🔒 این عملیات فقط برای مدیران فعال است."
        try:
            import wandb, json as _json
        except Exception as e:
            return f"❌ W&B در محیط در دسترس نیست: {e}"
        try:
            api = wandb.Api()
            proj_path = f"{wandb_entity}/{wandb_project}" if wandb_entity else wandb_project
            runs = api.runs(proj_path, filters={"state": "finished"})
        except Exception as e:
            return f"❌ عدم دسترسی به پروژه W&B ({wandb_project}): {e}"

        best_run = None
        best_val = float("-inf")
        for r in runs:
            s = r.summary or {}
            if "weights" in s and metric in s:
                try:
                    val = float(s[metric])
                except Exception:
                    continue
                if val > best_val:
                    best_val, best_run = val, r
        if not best_run:
            return "⚠️ هیچ Run واجد شرایطی با summary['weights'] و متریک موردنظر پیدا نشد."

        weights = best_run.summary.get("weights", {})
        if not isinstance(weights, dict) or not weights:
            return "⚠️ فرمت وزن‌های بهترین Run نامعتبر است."
        try:
            with open("legal_entity_weights.json", "w", encoding="utf-8") as f:
                _json.dump(weights, f, ensure_ascii=False, indent=2)
        except Exception as e:
            return f"❌ خطا در نوشتن legal_entity_weights.json: {e}"
        rid = getattr(best_run, "id", "unknown")
        return f"✅ وزن‌ها اعمال شد از Run `{rid}` با {metric}={best_val:.4f}. فایل: `legal_entity_weights.json`"

    # Consultation (public)
    def answer(self, question: str, system_prompt: str, use_rag: bool,
               max_new_tokens: int, temperature: float, top_p: float):
        """Answer a legal question; returns ``(answer_text, references_markdown)``.

        At most four retrieved articles are used as context and listed as
        references. The generation settings are written into the shared model
        config before generating.
        """
        if not (question or "").strip():
            return "لطفاً سوال خود را وارد کنید.", ""
        if not self.gen:
            return "ابتدا مدل را بارگذاری کنید.", ""
        self.scfg.model.max_new_tokens = int(max_new_tokens)
        self.scfg.model.temperature = float(temperature)
        self.scfg.model.top_p = float(top_p)

        arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
        max_refs = 4
        if arts:
            arts = arts[:max_refs]
        ctx = self.rag.build_context(arts) if arts else ""
        ans = self.gen.generate(question, ctx, system_prompt)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([
                f"**ماده {a['article_id']}** (شباهت: {a.get('similarity',0):.2f})\n{a['text'][:320]}..."
                for a in arts
            ])
        return ans, refs

    # UI
    def build_ui(self):
        """Build and return the Gradio ``Blocks`` app with all tabs and event wiring."""
        log_deps()
        try:
            print("[rag-bootstrap]", ensure_chroma_ready(self.scfg.rag.persist_dir, self.scfg.rag.collection), flush=True)
        except Exception as e:
            print("[rag-bootstrap] error:", e, flush=True)

        default_gen_models = {
            "Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct",
            "Llama-3.1-8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
            "Mistral-7B Instruct (v0.3)": "mistralai/Mistral-7B-Instruct-v0.3",
        }

        with gr.Blocks(title="ماحون — مشاور حقوقی (Causal-only, ZeroGPU)") as app:
            # Role banner
            role_banner = gr.Markdown()
            gr.Markdown(
                "\n\nماحون — Persian Legal (Causal-only, ZeroGPU)\n\n"
                "Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning\n\n"
            )

            # --- Tab: Consultation (interactive for all) ---
            with gr.Tab("مشاوره"):
                with gr.Row():
                    # NOTE(review): the dropdown is only used when the Model ID box
                    # is empty (see _resolve_gen) — confirm this is the intended UX.
                    gen_model_dd = gr.Dropdown(choices=list(default_gen_models.keys()), value="Qwen2.5-7B Instruct", label="مدل تولید")
                    gen_model_id = gr.Textbox(value=default_gen_models["Qwen2.5-7B Instruct"], label="Model ID (قابل ویرایش)")
                with gr.Row():
                    use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
                    persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر ChromaDB")
                    collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
                with gr.Row():
                    top_k = gr.Slider(1, 15, value=self.scfg.rag.top_k, step=1, label="Top-K")
                    threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="آستانه شباهت")
                load_btn = gr.Button("بارگذاری مدل", variant="primary")
                status = gr.Textbox(label="وضعیت", interactive=False)

                with gr.Accordion("پارامترهای تولید", open=False):
                    system_prompt = gr.Textbox(value="You are a helpful Persian legal assistant.", label="System prompt")
                    max_new_tokens = gr.Slider(64, 2048, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")

                question = gr.Textbox(lines=3, label="سوال حقوقی")
                gr.Examples(
                    examples=[
                        ["در صورت نقض قرارداد EPC چه راهکارهای حقوقی دارم؟"],
                        ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
                        ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
                    ],
                    inputs=question,
                    label="نمونه پرسش‌ها"
                )
                ask_btn = gr.Button("پرسش", variant="primary")
                answer = gr.Markdown(label="پاسخ")
                refs = gr.Markdown(label="مواد قانونی مرتبط")

            # --- Tab: Indexing (view-only for visitors) ---
            with gr.Tab("ایندکس قوانین"):
                gr.Markdown("فایل JSONL قوانین را بارگذاری و ایندکس کنید (کلیدها: `article_id`, `text`).")
                laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
                id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
                text_key = gr.Textbox(value="text", label="کلید متن ماده")
                index_btn = gr.Button("ایندکس‌سازی قوانین")
                index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
                index_widgets = [laws_file, id_key, text_key, index_btn]

            # --- Tab: Dataset Builder (view-only for visitors) ---
            with gr.Tab("ساخت دیتاست"):
                gr.Markdown("فایل خام (JSON/JSONL) → خروجی JSONL سازگار با `{input, output}` (از golden_builder).")
                raw_file = gr.File(label="فایل خام", file_types=[".json",".jsonl"])
                with gr.Row():
                    ds_text_key = gr.Textbox(value="متن_کامل", label="کلید متن (text_key)")
                    model_ckpt = gr.Dropdown(
                        choices=["google/mt5-base", "google/flan-t5-base", "t5-base"],
                        value="google/mt5-base",
                        label="مدل خلاصه‌ساز برای ساخت دیتاست (فقط Builder)"
                    )
                with gr.Row():
                    ds_batch_size = gr.Slider(1, 16, value=4, step=1, label="Batch size")
                    max_samples = gr.Number(value=0, label="حداکثر نمونه (۰=همه)")
                build_btn = gr.Button("ساخت دیتاست", variant="primary")
                out_file = gr.File(label="دانلود خروجی JSONL", interactive=False)
                build_status = gr.Textbox(label="وضعیت", interactive=False)
                builder_widgets = [raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples, build_btn]

            # --- Tab: Dataset Cleaning (view-only for visitors) ---
            with gr.Tab("پاکسازی دیتاست"):
                gr.Markdown("نرمال‌سازی فارسی + حذف تکراری‌های معنایی (cosine). ورودی: JSONL `{input, output}`.")
                raw_ds = gr.File(label="JSONL ورودی", file_types=[".jsonl"])
                sim_th = gr.Slider(0.80, 0.98, value=0.90, step=0.01, label="آستانه شباهت (cosine)")
                clean_btn = gr.Button("اجرای پاکسازی", variant="primary")
                cleaned_out = gr.File(label="دانلود JSONL پاک", interactive=False)
                clean_status = gr.Markdown()
                clean_widgets = [raw_ds, sim_th, clean_btn]

            # --- Tab: Training (view-only for visitors) ---
            with gr.Tab("آموزش"):
                gr.Markdown("SFT/LoRA روی مدل‌های causal (فقط `{input, output}`) + W&B logging.")
                with gr.Row():
                    model_train_dd = gr.Dropdown(
                        choices=[
                            "HAKIM (Editable ID below)",
                            "Hooshvareh (Editable ID below)",
                            "Dorna-Llama3-8B",
                            "PersianQA-8B",
                            "Custom (Editable ID below)"
                        ],
                        value="HAKIM (Editable ID below)",
                        label="پروفایل مدل"
                    )
                    model_train_id = gr.Textbox(value="AI-Hoosh/HAKIM-7B", label="HF Model ID (قابل ویرایش)")
                use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
                use_wandb = gr.Checkbox(value=True, label="W&B logging فعال باشد؟")
                wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
                wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
                run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
                gr.Markdown("راهنما: در Settings → Secrets مقدار `WANDB_API_KEY` را تنظیم کنید (مقدار واقعی).")
                train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
                with gr.Row():
                    epochs = gr.Slider(1, 6, value=2, step=1, label="epochs")
                    batch = gr.Slider(1, 8, value=2, step=1, label="batch per device")
                    lr = gr.Number(value=2e-4, label="learning rate")
                train_btn = gr.Button("شروع آموزش", variant="primary")
                train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
                train_widgets = [model_train_dd, model_train_id, use_rag_train, use_wandb, wandb_project,
                                 wandb_entity, run_name, train_files, epochs, batch, lr, train_btn]

            # --- Tab: Weight Tuning (view-only for visitors) ---
            with gr.Tab("Weight Tuning"):
                gr.Markdown("تیون خودکار وزن‌های موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
                tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json",".jsonl"])
                tune_text_key = gr.Textbox(value="متن_کامل", label="کلید متن")
                tune_max_samples = gr.Slider(50, 400, value=120, step=10, label="حداکثر نمونه")
                tune_runs = gr.Slider(4, 64, value=16, step=4, label="تعداد ران Sweep")
                tune_batch = gr.Slider(1, 4, value=2, step=1, label="batch size Builder")
                tune_proj = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
                tune_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
                run_tune = gr.Button("شروع Sweep", variant="primary")
                tune_status = gr.Markdown()
                gr.Markdown("---")
                gr.Markdown("اعمال خودکار بهترین وزن‌ها از داشبورد W&B (بر اساس بالاترین `pass_rate`).")
                metric_dd = gr.Dropdown(choices=["pass_rate"], value="pass_rate", label="متریک انتخاب بهترین Run")
                apply_btn = gr.Button("اعمال بهترین وزن‌ها از W&B", variant="secondary")
                tuning_widgets = [tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch,
                                  tune_proj, tune_entity, run_tune, metric_dd, apply_btn]

            # ---- Events (consultation is open / admin operations are gated) ----
            def _resolve_gen(choice: str, override: str) -> str:
                return override.strip() if override.strip() else default_gen_models[choice]

            def _on_load(choice, override, rag, pdir, coll, k, th):
                self.scfg.rag.enable = bool(rag)
                self.scfg.rag.persist_dir = pdir
                self.scfg.rag.collection = coll
                self.scfg.rag.top_k = int(k)
                self.scfg.rag.similarity_threshold = float(th)
                return self.load(_resolve_gen(choice, override))

            def _whoami(request: gr.Request):
                u = _get_username(request) or "Visitor"
                return f"👤 کاربر: **{u}** — دسترسی: {'مدیریتی' if is_admin(request) else 'بازدیدکننده (فقط مشاهده)'}"

            load_btn.click(_on_load,
                           inputs=[gen_model_dd, gen_model_id, use_rag, persist_dir, collection, top_k, threshold],
                           outputs=status)
            ask_btn.click(self.answer,
                          inputs=[question, system_prompt, use_rag, max_new_tokens, temperature, top_p],
                          outputs=[answer, refs])

            # Admin-only: rely on request injection (Gradio injects gr.Request automatically)
            def _index_handler(f, ik, tk, request: gr.Request):
                return self.build_index(f, ik, tk, request)
            index_btn.click(_index_handler, inputs=[laws_file, id_key, text_key], outputs=index_status)

            def _build_ds_handler(rf, tk, ckpt, bs, mx, request: gr.Request):
                return self.build_dataset(rf, tk, ckpt, bs, mx, request)
            build_btn.click(_build_ds_handler,
                            inputs=[raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples],
                            outputs=[out_file, build_status])

            def _train_handler(prof, mid, files, rg, e, b, l, uw, wp, we, rn, request: gr.Request):
                def _map_profile_to_id(profile: str, current_id: str) -> str:
                    # An explicit model id always wins over the profile.
                    if current_id.strip():
                        return current_id.strip()
                    if "Dorna" in profile:
                        return "PartAI/Dorna-Llama3-8B-Instruct"
                    if "PersianQA" in profile:
                        return "zpm/Llama-3.1-PersianQA"
                    if "HAKIM" in profile:
                        return "AI-Hoosh/HAKIM-7B"
                    if "Hooshvareh" in profile:
                        return "HooshvareLab/llama-fa-7b-instruct"
                    return "PartAI/Dorna-Llama3-8B-Instruct"
                model_id = _map_profile_to_id(prof, mid)
                return self.train(model_id, files, rg, e, b, l, uw, wp, we, rn, request=request)
            train_btn.click(_train_handler,
                            inputs=[model_train_dd, model_train_id, train_files, use_rag_train, epochs, batch, lr,
                                    use_wandb, wandb_project, wandb_entity, run_name],
                            outputs=train_status)

            def _clean_handler(f, th, request: gr.Request):
                # Server-side gate, matching the other admin tabs: disabling the
                # widgets in the browser alone does not protect the endpoint.
                if not is_admin(request):
                    return None, "🔒 این عملیات فقط برای مدیران فعال است."
                p = self._path_of(f)
                if not p:
                    return None, "⚠️ فایل نامعتبر."
                outp = f"/tmp/cleaned_{int(time.time())}.jsonl"
                try:
                    n = deduplicate_jsonl(p, outp, sim_threshold=float(th))
                except Exception as e:
                    return None, f"❌ خطا در پاکسازی دیتاست: {e}"
                return outp, f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{n}**"
            clean_btn.click(_clean_handler, inputs=[raw_ds, sim_th], outputs=[cleaned_out, clean_status])

            def _tune_handler(f, tk, ms, runs, bs, proj, ent, request: gr.Request):
                return self.run_weight_tune(f, tk, ms, runs, bs, proj, ent, request)
            run_tune.click(_tune_handler,
                           inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
                           outputs=tune_status)

            def _apply_best_handler(proj, ent, m, request: gr.Request):
                return self.apply_best_weights(proj, ent, m, request)
            apply_btn.click(_apply_best_handler, inputs=[tune_proj, tune_entity, metric_dd], outputs=tune_status)

            # --- Lock non-consultation tabs for visitors on load ---
            gated_widgets = index_widgets + builder_widgets + clean_widgets + train_widgets + tuning_widgets

            def _gate_all(request: gr.Request):
                admin = is_admin(request)
                role_txt = f"👤 کاربر: **{_get_username(request) or 'Visitor'}** — دسترسی: {'مدیریتی' if admin else 'بازدیدکننده (فقط مشاهده)'}"
                # One update per gated widget: interactive for admins, locked otherwise.
                return [role_txt] + [gr.update(interactive=admin)] * len(gated_widgets)

            app.load(_whoami, inputs=None, outputs=role_banner)
            app.load(_gate_all, inputs=None, outputs=[role_banner] + gated_widgets)

        return app


# ==========================
# Entrypoint
# ==========================
if __name__ == "__main__":
    app = LegalApp()
    ui = app.build_ui()
    try:
        ui = ui.queue()  # stable for ZeroGPU
    except TypeError:
        pass
    ui.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)