# -*- coding: utf-8 -*- """ Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B + ZeroGPU + Role Gating - تب «مشاوره» برای همه تعاملی است. - تبهای «ایندکس»، «ساخت دیتاست»، «پاکسازی»، «آموزش»، «Weight Tuning» برای بازدیدکننده فقط نمایشیاند؛ و سمتسرور نیز گِیت نقش دارد (ادمین/بازدیدکننده). پیشنیازها: - golden_builder.py , weights_sweep.py - Settings → Secrets: WANDB_API_KEY (در صورت استفاده از W&B) - Settings → Environment Variables: ADMIN_USERS (مثلاً: haji-mammad, teammate1) - requirements.txt (ZeroGPU-ready) شامل spaces>=0.42.0 """ from __future__ import annotations # --- Telemetry hard-off + ZeroGPU SDK (must be before chroma import) --- import os, logging os.environ["CHROMA_TELEMETRY_ENABLED"] = "false" os.environ["ANONYMIZED_TELEMETRY"] = "false" import spaces # ZeroGPU SDK # (اختیاری) کاهش نویز لاگها logging.getLogger("chromadb").setLevel(logging.ERROR) logging.getLogger("posthog").setLevel(logging.CRITICAL) # ----------------------------------------------------------------------- import sys, re, json, time, pickle, zipfile, warnings from dataclasses import dataclass, field from pathlib import Path from typing import List, Dict, Optional import numpy as np import torch from torch.utils.data import Dataset from sklearn.model_selection import train_test_split import gradio as gr warnings.filterwarnings("ignore") # ====== Transformers ====== import transformers as tf from transformers import ( AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, EarlyStoppingCallback ) # ====== RAG stack ====== import chromadb from chromadb.config import Settings from rank_bm25 import BM25Okapi from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util # ---- Monkeypatch Chroma telemetry (fallback) ---- try: import chromadb.telemetry as _ctel try: _ctel.client = None except Exception: pass for _n in ("capture", "capture_event"): if hasattr(_ctel, _n): try: setattr(_ctel, _n, lambda *a, **k: None) except Exception: pass if hasattr(_ctel, "Telemetry"): try: _ctel.Telemetry().capture = lambda *a, **k: None except Exception: pass except Exception: pass # ------------------------------------------------- # ========= Persian normalization ========= ZWNJ = "\u200c" AR_DIGITS = "٠١٢٣٤٥٦٧٨٩" FA_DIGITS = "۰۱۲۳۴۵۶۷۸۹" EN_DIGITS = "0123456789" def normalize_fa(s: str) -> str: if not s: return s s = s.replace("\u064A", "ی").replace("\u0643", "ک") s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s) trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)} s = s.translate(trans) s = re.sub(r"\s*\s*", ZWNJ, s) s = re.sub(r"\s+", " ", s).strip() return s # ========================== # Configs # ========================== @dataclass class ModelConfig: model_name: str = "Qwen/Qwen2.5-7B-Instruct" max_input_length: int = 3072 max_new_tokens: int = 256 temperature: float = 0.7 top_p: float = 0.9 do_sample: bool = True gradient_checkpointing: bool = True @dataclass class RAGConfig: persist_dir: str = "./chroma_db" collection: str = "legal_articles" top_k: int = 6 similarity_threshold: float = 0.68 context_char_limit: int = 260 enable: bool = True reranker_name: str = "Alibaba-NLP/gte-multilingual-reranker-base" @dataclass class TrainConfig: base_model: str = "PartAI/Dorna-Llama3-8B-Instruct" alt_model_1: str = "zpm/Llama-3.1-PersianQA" hakim_model: str = "AI-Hoosh/HAKIM-7B" hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct" output_dir: str = "./mahoon_causal_lora" seed: int = 42 test_size: float = 0.1 epochs: int = 2 batch_size: int = 2 grad_accum: int = 4 lr: float = 2e-4 warmup_ratio: float = 0.03 weight_decay: float = 0.0 logging_steps: int = 50 eval_strategy: str = "epoch" save_strategy: str = "epoch" save_total_limit: int = 2 report_to: str = "wandb" max_grad_norm: float = 1.0 use_4bit: bool = False max_seq_len: int = 2048 @dataclass class SystemConfig: model: ModelConfig = field(default_factory=ModelConfig) rag: RAGConfig = field(default_factory=RAGConfig) train: TrainConfig = field(default_factory=TrainConfig) # ========================== # Helpers # ========================== def set_seed_all(seed: int = 42): import random random.seed(seed); np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def bf16_supported(): return torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)() def log_deps(): try: import accelerate, datasets print("[deps]", f"python={sys.version.split()[0]}", f"transformers={tf.__version__}", f"accelerate={accelerate.__version__}", f"datasets={datasets.__version__}", f"gradio={gr.__version__}", flush=True) except Exception as e: print("[deps] warn:", e, flush=True) # ========================== # Role gating helpers # ========================== def _get_username(request: gr.Request) -> str | None: try: return getattr(request, "username", None) except Exception: return None def is_admin(request: gr.Request) -> bool: uname = _get_username(request) if not uname: return False author = os.getenv("SPACE_AUTHOR_NAME", "").strip() allow = {u.strip() for u in os.getenv("ADMIN_USERS", "").split(",") if u.strip()} return (uname == author) or (uname in allow) # ========================== # RAG: Chroma + BM25 + CrossEncoder reranker # ========================== class LegalRAG: def __init__(self, cfg: RAGConfig): self.cfg = cfg self.client = None self.collection = None self.reranker: Optional[CrossEncoder] = None self.bm25 = None self.bm25_ids: List[str] = [] self.bm25_path = str(Path(self.cfg.persist_dir) / "bm25.pkl") def init(self): Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True) self.client = chromadb.PersistentClient( path=self.cfg.persist_dir, settings=Settings(anonymized_telemetry=False) ) try: self.collection = self.client.get_or_create_collection(self.cfg.collection) except Exception: try: self.collection = self.client.get_collection(self.cfg.collection) except Exception: self.collection = self.client.create_collection(self.cfg.collection) try: self.reranker = CrossEncoder(self.cfg.reranker_name, device="cpu") except Exception: self.reranker = None if Path(self.bm25_path).exists(): with open(self.bm25_path, "rb") as f: obj = pickle.load(f) self.bm25 = obj["bm25"]; self.bm25_ids = obj["ids"] def _rebuild_bm25(self, ids: List[str], docs: List[str]): corpus = [normalize_fa(d).split() for d in docs] self.bm25 = BM25Okapi(corpus) self.bm25_ids = ids with open(self.bm25_path, "wb") as f: pickle.dump({"bm25": self.bm25, "ids": self.bm25_ids}, f) def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"): if not self.collection: self.init() seen: Dict[str, int] = {} ids, docs, metas = [], [], [] def _norm_id(x: str) -> str: x = x or "" x = x.replace("\u064A", "ی").replace("\u0643", "ک") trans = {ord(a): e for a, e in zip("٠١٢٣٤٥٦٧٨٩۰۱۲۳۴۵۶۷۸۹", "01234567890123456789")} x = x.translate(trans) x = re.sub(r"\s+", "", x) return x with open(jsonl_path, "r", encoding="utf-8") as f: for i, line in enumerate(f): s = line.strip() if not s: continue try: obj = json.loads(s) except: continue raw_id = str(obj.get(id_key, f"auto_{i}")) base_id = _norm_id(raw_id) txt = normalize_fa(str(obj.get(text_key, "")).strip()) if not txt: continue if base_id in seen: seen[base_id] += 1 uid = f"{base_id}__d{seen[base_id]}" dupe_idx = seen[base_id] else: seen[base_id] = 1 uid = base_id dupe_idx = 1 ids.append(uid); docs.append(txt); metas.append({"article_id": base_id, "dupe_idx": dupe_idx}) if not ids: return "هیچ سندی برای ایندکس یافت نشد." self.collection.upsert(ids=ids, documents=docs, metadatas=metas) self._rebuild_bm25(ids, docs) dup_count = sum(1 for _, c in seen.items() if c > 1) return f"✅ {len(ids)} سند ایندکس شد (Dense+BM25). شناسههای تکراری: {dup_count} کلید (با پسوند __dN یکتا شدند)." def retrieve(self, query: str) -> List[Dict]: if not self.collection: return [] qn = normalize_fa(query) # Dense try: res = self.collection.query( query_texts=[qn], n_results=max(self.cfg.top_k * 3, 20), include=["documents", "metadatas", "distances"], ) out = [] docs = res.get("documents", [[]])[0] metas = res.get("metadatas", [[]])[0] dists = res.get("distances", [[1.0]])[0] for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)): sim = 1.0 - float(dist) out.append({"article_id": (meta or {}).get("article_id", f"unk_{i}"), "text": doc, "similarity": sim}) except Exception: out = [] # BM25 bm25_hits = [] if self.bm25 is not None and self.bm25_ids: scores = self.bm25.get_scores(normalize_fa(qn).split()) idxs = np.argsort(scores)[::-1][:max(self.cfg.top_k * 3, 20)] smax = float(scores.max() + 1e-8) for j in idxs: aid = self.bm25_ids[int(j)] try: got = self.collection.get(ids=[aid]) tdoc = got["documents"][0] except Exception: tdoc = "" bm25_hits.append({"article_id": aid, "text": tdoc, "similarity": float(scores[j]) / smax}) # merge pool: Dict[str, Dict] = {} for a in out + bm25_hits: if a["article_id"] not in pool or a.get("similarity", 0) > pool[a["article_id"]].get("similarity", 0): pool[a["article_id"]] = a merged = [a for a in pool.values() if a.get("text") and len(a["text"]) > 15] merged = [a for a in merged if a.get("similarity", 0) >= self.cfg.similarity_threshold] # rerank (GPU only during predict) if merged and self.reranker: pairs = [(qn, a["text"]) for a in merged] try: with spaces.GPU(duration=30): scores = self.reranker.predict(pairs) except Exception: scores = self.reranker.predict(pairs) for a, s in zip(merged, scores): a["score"] = float(s) merged = sorted(merged, key=lambda x: x.get("score", 0), reverse=True)[: self.cfg.top_k] else: merged = sorted(merged, key=lambda x: x.get("similarity", 0), reverse=True)[: self.cfg.top_k] return merged def build_context(self, arts: List[Dict]) -> str: if not arts: return "" bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." for a in arts] return "مواد مرتبط:\n" + "\n".join(bullets) # ========= RAG bootstrap from repo ========= def parse_law_textfile_to_jsonl(txt_path: str, out_jsonl: str): pat = re.compile(r"(?:ماده|مادّه)\s+(\d+)\s*[:\-–]\s*(.+)") rows = [] with open(txt_path, "r", encoding="utf-8") as f: for line in f: s = line.strip() if not s: continue m = pat.match(s) if not m: continue aid = m.group(1); body = m.group(2).strip() if len(body) < 12: continue rows.append({"article_id": aid, "text": normalize_fa(body)}) if not rows: raise RuntimeError("هیچ مادهای با الگوی تعریفشده پیدا نشد.") with open(out_jsonl, "w", encoding="utf-8") as g: for r in rows: g.write(json.dumps(r, ensure_ascii=False) + "\n") return len(rows) def ensure_chroma_ready(persist_dir="./chroma_db", collection="legal_articles") -> str: Path(persist_dir).mkdir(parents=True, exist_ok=True) if any(Path(persist_dir).glob("*")): return f"ChromaDB موجود است." zip_path = Path("./chroma_legal_db.zip") if zip_path.exists(): try: with zipfile.ZipFile(zip_path, "r") as z: z.extractall(persist_dir) return "ChromaDB از zip بازیابی شد." except Exception: pass txt_path = Path("./all_legal_sentences.txt") if txt_path.exists(): n = parse_law_textfile_to_jsonl(str(txt_path), "./laws.jsonl") rag_local = LegalRAG(RAGConfig(persist_dir=persist_dir, collection=collection)) rag_local.init() msg = rag_local.index_jsonl("./laws.jsonl", id_key="article_id", text_key="text") return f"از متن خام {n} رکورد استخراج شد. {msg}" return "پایگاه RAG موجود نیست و منبع خامی هم برای ساخت پیدا نشد." # ========================== # Loader + Generator (Causal-only, ZeroGPU) # ========================== class CausalLoader: def __init__(self, mcfg: ModelConfig): self.cfg = mcfg self.tokenizer = None self.model = None def load(self, model_name: str): self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True) if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"): self.tokenizer.pad_token = self.tokenizer.eos_token try: with spaces.GPU(duration=90): kwargs = {"low_cpu_mem_usage": True} if torch.cuda.is_available(): kwargs["device_map"] = "auto" kwargs["torch_dtype"] = torch.bfloat16 if bf16_supported() else torch.float16 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"): try: self.model.gradient_checkpointing_enable() except Exception: pass except Exception: self.model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True) return self class Generator: def __init__(self, loader: CausalLoader, mcfg: ModelConfig): self.tk = loader.tokenizer self.model = loader.model self.cfg = mcfg def generate(self, question: str, context: str = "", system_prompt: str = "You are a helpful Persian legal assistant.") -> str: parts = [] if system_prompt: parts.append(f"<|system|>\n{system_prompt}") if context: parts.append(f"<|system|>\nاز منابع زیر استفاده کن و استنادی پاسخ بده:\n{context}") parts.append(f"<|user|>\n{question}") prompt = "\n".join(parts) + "\n<|assistant|>\n" enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length) try: with spaces.GPU(duration=60): dev_model = next(self.model.parameters()).device if hasattr(self.model, "parameters") else "cpu" inputs = {k: v.to(dev_model) for k, v in enc.items()} with torch.no_grad(): out = self.model.generate( **inputs, max_new_tokens=self.cfg.max_new_tokens, do_sample=self.cfg.do_sample, temperature=self.cfg.temperature, top_p=self.cfg.top_p, pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id, ) except Exception: inputs = {k: v for k, v in enc.items()} with torch.no_grad(): out = self.model.generate( **inputs, max_new_tokens=min(self.cfg.max_new_tokens, 256), do_sample=self.cfg.do_sample, temperature=self.cfg.temperature, top_p=self.cfg.top_p, pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id, ) return self.tk.decode(out[0], skip_special_tokens=True) # ========================== # Datasets & Trainer (Causal-only, W&B) # ========================== def read_jsonl_files(paths: List[str]) -> List[Dict]: data: List[Dict] = [] for p in paths: if not p: continue with open(p, 'r', encoding='utf-8') as f: for line in f: s = line.strip() if not s: continue try: data.append(json.loads(s)) except json.JSONDecodeError: continue return data class CausalJSONLDataset(Dataset): def __init__(self, data: List[Dict], tokenizer, max_len: int, rag: Optional[LegalRAG] = None, enhance_every:int = 8): self.tk = tokenizer self.max_len = max_len self.items = [] for i, ex in enumerate(data): src = normalize_fa(str(ex.get("input", "")).strip()) tgt = normalize_fa(str(ex.get("output", "")).strip()) if not src or not tgt: continue ctx = "" if rag and i % enhance_every == 0: arts = rag.retrieve(src) ctx = rag.build_context(arts) text = "" if ctx: text += f"<|system|>\nاز منابع زیر استفاده کن:\n{ctx}\n" text += f"<|system|>\nYou are a helpful Persian legal assistant.\n" text += f"<|user|>\n{src}\n<|assistant|>\n{tgt}" self.items.append(text) def __len__(self): return len(self.items) def __getitem__(self, idx): text = self.items[idx] enc = self.tk(text, max_length=self.max_len, padding="max_length", truncation=True) input_ids = torch.tensor(enc["input_ids"]) attn = torch.tensor(enc["attention_mask"]) labels = input_ids.clone(); labels[attn == 0] = -100 return {"input_ids": input_ids, "attention_mask": attn, "labels": labels} def safe_training_args(**kwargs): return TrainingArguments(**kwargs) class TrainerManager: def __init__(self, syscfg: SystemConfig, loader: CausalLoader): self.cfg = syscfg self.loader = loader def train_causal(self, train_paths: List[str], use_rag: bool = True, use_wandb: bool = True, wandb_project: str = "mahoon-legal-ai", wandb_entity: str = "", run_name: str = "mahoon_causal_lora"): set_seed_all(self.cfg.train.seed) data = read_jsonl_files(train_paths) train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed) rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None if rag: rag.init() ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.train.max_seq_len, rag) ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.train.max_seq_len, None) fp16_ok = torch.cuda.is_available() and not bf16_supported() bf16_ok = bf16_supported() if use_wandb: os.environ.setdefault("WANDB_PROJECT", wandb_project or "mahoon-legal-ai") if wandb_entity: os.environ.setdefault("WANDB_ENTITY", wandb_entity) os.environ.pop("WANDB_DISABLED", None) else: os.environ["WANDB_DISABLED"] = "true" args = safe_training_args( output_dir=self.cfg.train.output_dir, num_train_epochs=self.cfg.train.epochs, learning_rate=self.cfg.train.lr, per_device_train_batch_size=self.cfg.train.batch_size, per_device_eval_batch_size=self.cfg.train.batch_size, gradient_accumulation_steps=self.cfg.train.grad_accum, warmup_ratio=self.cfg.train.warmup_ratio, weight_decay=self.cfg.train.weight_decay, evaluation_strategy=self.cfg.train.eval_strategy, save_strategy=self.cfg.train.save_strategy, save_total_limit=self.cfg.train.save_total_limit, load_best_model_at_end=True, metric_for_best_model="eval_loss", logging_steps=self.cfg.train.logging_steps, report_to=(["wandb"] if use_wandb else ["none"]), run_name=run_name, fp16=fp16_ok, bf16=bf16_ok, max_grad_norm=self.cfg.train.max_grad_norm, ) callbacks = [EarlyStoppingCallback(early_stopping_patience=2)] try: if use_wandb: from transformers.integrations import WandbCallback callbacks.append(WandbCallback()) except Exception: pass trainer = Trainer( model=self.loader.model, args=args, train_dataset=ds_tr, eval_dataset=ds_va, tokenizer=self.loader.tokenizer, callbacks=callbacks, ) if use_wandb: try: import wandb wandb.init(project=os.getenv("WANDB_PROJECT", "mahoon-legal-ai"), entity=os.getenv("WANDB_ENTITY"), name=run_name, config={ "base_model": self.loader.model.name_or_path, "epochs": self.cfg.train.epochs, "batch": self.cfg.train.batch_size, "grad_accum": self.cfg.train.grad_accum, "lr": self.cfg.train.lr, "max_seq_len": self.cfg.train.max_seq_len, "use_rag": use_rag, }) except Exception: pass trainer.train() trainer.save_model(self.cfg.train.output_dir) self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir) if use_wandb: try: import wandb art = wandb.Artifact("mahoon-model", type="model") art.add_dir(self.cfg.train.output_dir) wandb.log_artifact(art) wandb.finish() except Exception: pass # ========================== # Dataset utilities (Cleaner/Deduper) # ========================== def deduplicate_jsonl(in_path: str, out_path: str, sim_threshold: float = 0.90, text_keys=("input","output")) -> int: rows = [] with open(in_path, "r", encoding="utf-8") as f: for line in f: s = line.strip() if not s: continue try: obj = json.loads(s) except: continue for k in text_keys: if k in obj: obj[k] = normalize_fa(str(obj[k])) rows.append(obj) if not rows: raise RuntimeError("هیچ رکورد معتبری در ورودی نبود.") model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") embs = model.encode([r.get("input","") for r in rows], convert_to_tensor=True, show_progress_bar=False, normalize_embeddings=True) kept, seen = [], torch.zeros(len(rows), dtype=torch.bool) for i in range(len(rows)): if seen[i]: continue sims = st_util.cos_sim(embs[i], embs)[0] dup_idx = (sims >= sim_threshold).nonzero(as_tuple=True)[0].tolist() for j in dup_idx: seen[j] = True kept.append(rows[i]) with open(out_path, "w", encoding="utf-8") as g: for r in kept: g.write(json.dumps(r, ensure_ascii=False) + "\n") return len(kept) # ========================== # App (Gradio) + Role Gating # ========================== class LegalApp: def __init__(self, scfg: Optional[SystemConfig] = None): self.scfg = scfg or SystemConfig() self.rag = LegalRAG(self.scfg.rag) self.loader: Optional[CausalLoader] = None self.gen: Optional[Generator] = None def _file_paths(self, files: List[gr.File]) -> List[str]: paths = [] for f in (files or []): p = getattr(f, "name", None) or getattr(f, "path", None) if p: paths.append(p) return paths # Core (مشاوره/لود آزاد است) def load(self, model_name: str): self.loader = CausalLoader(self.scfg.model).load(model_name) self.gen = Generator(self.loader, self.scfg.model) # RAG msg_rag = "RAG غیرفعال" if self.scfg.rag.enable: try: self.rag = LegalRAG(self.scfg.rag); self.rag.init() msg_rag = "RAG آماده است" except Exception as e: msg_rag = f"RAG خطا: {e}" return f"مدل بارگذاری شد: {model_name}\n{msg_rag}" # --- گیت سمتسرور: فقط ادمین --- def build_index(self, laws_file: gr.File, id_key: str, text_key: str, request: gr.Request): if not is_admin(request): return "🔒 این عملیات فقط برای مدیران فعال است." if not self.scfg.rag.enable: return "RAG غیرفعال است." try: self.rag.init() p = getattr(laws_file, "name", None) or getattr(laws_file, "path", None) if not p: return "فایل قوانین معتبر نیست." return self.rag.index_jsonl(p, id_key=id_key, text_key=text_key) except Exception as e: return f"خطا در ایندکس: {e}" def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None, request: gr.Request): if not is_admin(request): return None, "🔒 این عملیات فقط برای مدیران فعال است." try: from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder except Exception as e: return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}" path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None) if not path: return None, "⚠️ فایل ورودی معتبر نیست." try: data = load_json_or_jsonl(path) if max_samples and int(max_samples) > 0: data = data[:int(max_samples)] gb = GoldenBuilder(model_name=model_ckpt) rows = gb.build(data, text_key=text_key, batch_size=int(batch_size)) out_dir = "/tmp/mahoon_datasets"; Path(out_dir).mkdir(parents=True, exist_ok=True) out_path = f"{out_dir}/golden_{os.path.basename(path)}.jsonl" save_jsonl(rows, out_path) return out_path, f"✅ {len(rows)} رکورد تولید شد." except Exception as e: return None, f"❌ خطا در ساخت دیتاست: {e}" def train(self, model_name: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float, use_wandb: bool, wandb_project: str, wandb_entity: str, run_name: str, progress=gr.Progress(track_tqdm=True), request: gr.Request = None): if not is_admin(request): return "🔒 این عملیات فقط برای مدیران فعال است." progress(0.05, desc="راهاندازی") self.scfg.train.epochs = int(epochs) self.scfg.train.batch_size = int(batch) self.scfg.train.lr = float(lr) progress(0.10, desc="بارگذاری مدل/توکنایزر") self.loader = CausalLoader(self.scfg.model).load(model_name) paths = self._file_paths(files) if not paths: return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده." tm = TrainerManager(self.scfg, self.loader) set_seed_all(self.scfg.train.seed) progress(0.30, desc="آمادهسازی دیتاستها و RAG (اختیاری)") tm.train_causal( paths, use_rag=use_rag, use_wandb=use_wandb, wandb_project=wandb_project, wandb_entity=wandb_entity, run_name=run_name ) progress(0.95, desc="ذخیرهٔ آرتیفکتها") return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد." def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent, request: gr.Request): if not is_admin(request): return "🔒 این عملیات فقط برای مدیران فعال است." p = getattr(f, "name", None) or getattr(f, "path", None) if not p: return "⚠️ فایل داده نامعتبر است." try: from weights_sweep import run_sweep except Exception as e: return f"❌ weights_sweep.py یافت نشد/قابل import نیست: {e}" os.environ.setdefault("WANDB_PROJECT", proj or "mahoon-legal-ai") if ent: os.environ.setdefault("WANDB_ENTITY", ent) try: run_sweep(data_path=p, text_key=tk, max_samples=int(ms), batch_size=int(bs), project=proj, entity=ent, count=int(runs)) return "✅ Sweep اجرا شد. بهترین Run را در W&B بررسی و وزنها را تثبیت کنید." except Exception as e: return f"❌ خطا در اجرای Sweep: {e}" def apply_best_weights(self, wandb_project: str, wandb_entity: str, metric: str = "pass_rate", request: gr.Request = None): if request is not None and not is_admin(request): return "🔒 این عملیات فقط برای مدیران فعال است." try: import wandb, json as _json except Exception as e: return f"❌ W&B در محیط در دسترس نیست: {e}" try: api = wandb.Api() proj_path = f"{wandb_entity}/{wandb_project}" if wandb_entity else wandb_project runs = api.runs(proj_path, filters={"state": "finished"}) except Exception as e: return f"❌ عدم دسترسی به پروژه W&B ({wandb_project}): {e}" best_run = None; best_val = float("-inf") for r in runs: s = r.summary or {} if "weights" in s and metric in s: try: val = float(s[metric]) except Exception: continue if val > best_val: best_val, best_run = val, r if not best_run: return "⚠️ هیچ Run واجد شرایطی با summary['weights'] و متریک موردنظر پیدا نشد." weights = best_run.summary.get("weights", {}) if not isinstance(weights, dict) or not weights: return "⚠️ فرمت وزنهای بهترین Run نامعتبر است." try: with open("legal_entity_weights.json", "w", encoding="utf-8") as f: _json.dump(weights, f, ensure_ascii=False, indent=2) except Exception as e: return f"❌ خطا در نوشتن legal_entity_weights.json: {e}" rid = getattr(best_run, "id", "unknown") return f"✅ وزنها اعمال شد از Run `{rid}` با {metric}={best_val:.4f}. فایل: `legal_entity_weights.json`" # Consultation (عمومی) def answer(self, question: str, system_prompt: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float): if not question.strip(): return "لطفاً سوال خود را وارد کنید.", "" if not self.gen: return "ابتدا مدل را بارگذاری کنید.", "" self.scfg.model.max_new_tokens = int(max_new_tokens) self.scfg.model.temperature = float(temperature) self.scfg.model.top_p = float(top_p) arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else [] max_refs = 4 if arts: arts = arts[:max_refs] ctx = self.rag.build_context(arts) if arts else "" ans = self.gen.generate(question, ctx, system_prompt) refs = "" if arts: refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a.get('similarity',0):.2f})\n{a['text'][:320]}..." for a in arts]) return ans, refs # UI def build_ui(self): log_deps() try: print("[rag-bootstrap]", ensure_chroma_ready(self.scfg.rag.persist_dir, self.scfg.rag.collection), flush=True) except Exception as e: print("[rag-bootstrap] error:", e, flush=True) default_gen_models = { "Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct", "Llama-3.1-8B Instruct": "meta-llama/Llama-3.1-8B-Instruct", "Mistral-7B Instruct (v0.3)": "mistralai/Mistral-7B-Instruct-v0.3", } with gr.Blocks(title="ماحون — مشاور حقوقی (Causal-only, ZeroGPU)") as app: # بنر نقش role_banner = gr.Markdown() gr.Markdown("""
Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning