# -------------------------------
# app.py (CPU COMPATIBLE VERSION)
#
# This file contains the backend logic and Gradio UI for the chatbot.
#
# --- FINAL, WORKING VERSION ---
# - Specifies target_modules in LoraConfig to work with the custom Sam2 model.
# - Uses a pure PyTorch fine-tuning loop for maximum control and stability.
# - Custom Sam2Config inherits from PretrainedConfig to solve subscriptable errors.
# - UI polling is backward-compatible with older Gradio versions.
# -------------------------------
import time
import math
import json
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pathlib import Path
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
import gradio as gr
import os
from datetime import datetime
import threading
import traceback
import spaces

# --- RLHF & Training Imports ---
from huggingface_hub import HfApi, login
from datasets import Dataset, load_dataset, concatenate_datasets
from peft import LoraConfig, get_peft_model

# -------------------------------
# 0) RLHF & TUNING CONFIGURATION
# -------------------------------
FEEDBACK_DATASET_REPO = "Smilyai-labs/Open-Sam-2.5-chat"
TUNED_MODEL_REPO_OWNER = "Smilyai-labs"
BASE_MODEL_REPO = "Smilyai-labs/Sam-2.5-PRO-SOLVER-V2"
FINETUNE_TRIGGER_LIKES = 8  # fine-tuning is triggered every N likes
MIN_LIKES_FOR_TRAINING = 2  # skip a tuning run with fewer liked examples than this

# --- PyTorch Training Config ---
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
BATCH_SIZE = 1

# Seconds to wait on the Hub before a network request is considered stalled.
HTTP_TIMEOUT = 60

# --- Login to Hugging Face Hub ---
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("WARNING: Hugging Face token not found. Feedback will not be saved and tuning will not run.")
else:
    login(token=HF_TOKEN)
    print("Hugging Face token found. Feedback logging and model tuning are enabled.")

# --- Global state ---
LIKE_COUNTER = 0  # likes seen since process start (guarded by like_counter_lock)
like_counter_lock = threading.Lock()
training_lock = threading.Lock()  # ensures only one fine-tuning run at a time
model_lock = threading.Lock()  # serializes inference against the live-model hot swap
TRAINING_STATUS = ""  # banner text shown in the UI; empty string hides the banner


# -------------------------------
# 1) Local Sam-2 architecture
# -------------------------------
class Sam2Config(PretrainedConfig):
    """Hyper-parameters of the local Sam-2 decoder-only transformer.

    Inheriting from ``PretrainedConfig`` makes the config usable by PEFT and
    ``save_pretrained``; unknown keys from ``config.json`` are forwarded to the
    parent via ``**kwargs``.
    """

    model_type = "sam2"

    def __init__(
        self,
        vocab_size=32000,
        d_model=384,
        n_layers=6,
        n_heads=6,
        ff_mult=4.0,
        dropout=0.1,
        input_modality="text",
        head_type="causal_lm",
        version="0.1",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.ff_mult = ff_mult
        self.dropout = dropout
        self.input_modality = input_modality
        self.head_type = head_type
        self.version = version
        super().__init__(**kwargs)


class RMSNorm(nn.Module):
    """Root-mean-square layer norm over the last dimension with a learned scale."""

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return self.weight * x * (x.pow(2).mean(-1, keepdim=True) + self.eps).rsqrt()


class MHA(nn.Module):
    """Causal multi-head self-attention with an optional key padding mask."""

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape (B, T, C).

        ``attn_mask`` is expected to be a (B, T) mask with non-zero entries for
        real tokens (as produced by the tokenizer); masked keys are excluded.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        if attn_mask is not None:
            # (B, T) -> (B, 1, 1, T): broadcast over heads and query positions.
            scores = scores.masked_fill(~attn_mask.unsqueeze(1).unsqueeze(2).bool(), float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        # A query row whose visible keys are all masked (e.g. leading padding)
        # softmaxes to NaN; zero it so the NaN cannot poison hidden states/loss.
        attn = torch.nan_to_num(attn, nan=0.0)
        out = torch.matmul(self.dropout(attn), v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


class SwiGLU(nn.Module):
    """Gated feed-forward layer: w3(silu(w1 x) * w2 x)."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w3(self.dropout(torch.nn.functional.silu(self.w1(x)) * self.w2(x)))


class Block(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each with a residual."""

    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        x = x + self.drop(self.attn(self.norm1(x), attn_mask=attn_mask))
        x = x + self.drop(self.ff(self.norm2(x)))
        return x


class Sam2(PreTrainedModel):
    """Sam-2 causal language model (tied input embedding / output head)."""

    config_class = Sam2Config

    def __init__(self, config: Sam2Config):
        super().__init__(config)
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList(
            [Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout) for _ in range(config.n_layers)]
        )
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embed.weight

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {"input_ids": input_ids}

    def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
        """Run the model.

        Returns ``(loss, logits)`` when ``labels`` is given, otherwise ``(logits,)``.
        The loss is next-token cross-entropy; label positions equal to -100
        (``CrossEntropyLoss``'s default ``ignore_index``) are ignored.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is not None:
            x = inputs_embeds
        else:
            if input_ids is None:
                raise ValueError("You must provide either input_ids or inputs_embeds")
            x = self.embed(input_ids)
        for blk in self.blocks:
            x = blk(x, attn_mask=attention_mask)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if loss is not None:
            return (loss, logits)
        return (logits,)


# -------------------------------
# 2) Load initial resources
# -------------------------------
weights_filename = "model.safetensors"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_REPO)
tokenizer.pad_token = tokenizer.eos_token

config_url = f"https://huggingface.co/{BASE_MODEL_REPO}/raw/main/config.json"
config_response = requests.get(config_url, timeout=HTTP_TIMEOUT)
# Fail loudly on HTTP errors instead of parsing an error page as the config.
config_response.raise_for_status()
config_data = config_response.json()
cfg = Sam2Config(**config_data)

weights_url = f"https://huggingface.co/{BASE_MODEL_REPO}/resolve/main/{weights_filename}"
# Stream to disk so the whole checkpoint is never buffered in memory, and
# refuse to write an HTTP error body as "model.safetensors".
with requests.get(weights_url, stream=True, timeout=HTTP_TIMEOUT) as weights_response:
    weights_response.raise_for_status()
    with open(weights_filename, "wb") as f:
        for chunk in weights_response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

model = Sam2(cfg)
state_dict = load_file(weights_filename)
model.load_state_dict(state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print(f"Inference will run on: {device}")

SPECIAL_TOKENS = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>"}

# convert_tokens_to_ids returns unk_token_id (or None) for unknown tokens rather
# than a falsy value, so an `or` fallback would never trigger; check explicitly.
_eot_lookup = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS["eot"])
if _eot_lookup is None or _eot_lookup == tokenizer.unk_token_id:
    _eot_lookup = tokenizer.eos_token_id
EOT_ID = _eot_lookup

SYSTEM_PROMPT = "You are Sam-2, a friendly and concise chatbot. Always give short, direct answers and avoid medical or legal advice."
AutoModelForCausalLM.register(Sam2Config, Sam2)


# -------------------------------
# 3) Inference and Feedback Functions
# -------------------------------
def _banned_ngram_tokens(past_list, ngram_size):
    """Return the tokens that would complete an n-gram already present in ``past_list``.

    Standard "no repeat n-gram" rule: if the last ``ngram_size - 1`` tokens
    occurred earlier followed by token ``t``, generating ``t`` now would repeat
    that n-gram, so ``t`` is banned. Runs in O(len(past_list) * ngram_size).
    """
    if not ngram_size or ngram_size <= 0 or len(past_list) < ngram_size:
        return set()
    if ngram_size == 1:
        # Every previously seen token would form a repeated unigram.
        return set(past_list)
    prefix_len = ngram_size - 1
    prefix = tuple(past_list[-prefix_len:])
    banned = set()
    for start in range(len(past_list) - ngram_size + 1):
        if tuple(past_list[start:start + prefix_len]) == prefix:
            banned.add(past_list[start + prefix_len])
    return banned


@spaces.GPU
def sample_next_token(
    logits,
    past_tokens,
    temperature=0.8,
    top_k=40,
    top_p=0.9,
    repetition_penalty=1.1,
    max_repeat=5,
    no_repeat_ngram_size=3,
):
    """Sample one next token id per batch row from ``logits``.

    Args:
        logits: (B, T, V) model output (the last position is used) or (B, V).
        past_tokens: 1-D tensor or sequence of already-seen token ids.
        temperature: divides the logits; values < 1 sharpen the distribution.
        top_k / top_p: standard top-k and nucleus filtering.
        repetition_penalty: CTRL-style penalty (> 1 discourages seen tokens).
        max_repeat: ban a token that already ended the sequence this many times in a row.
        no_repeat_ngram_size: forbid repeating any n-gram of this size (0 disables).

    Returns:
        (B, 1) long tensor of sampled token ids on ``device``.
    """
    if logits.dim() == 3:
        logits = logits[:, -1, :].clone()
    else:
        logits = logits.clone()
    batch_size, vocab_size = logits.size(0), logits.size(1)
    orig_logits = logits.clone()

    if temperature != 1.0:
        logits = logits / float(temperature)

    past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)

    # Repetition penalty: dividing a negative logit would *raise* its probability,
    # so negative logits are multiplied instead (same rule as HF/CTRL).
    seen = [token_id for token_id in set(past_list) if 0 <= token_id < vocab_size]
    if seen and repetition_penalty != 1.0:
        seen_idx = torch.tensor(seen, device=logits.device, dtype=torch.long)
        selected = logits[:, seen_idx]
        logits[:, seen_idx] = torch.where(selected < 0, selected * repetition_penalty, selected / repetition_penalty)

    # Hard-ban a token that has just been repeated `max_repeat` times in a row.
    if len(past_list) >= max_repeat:
        last_token, count = past_list[-1], 1
        for i in reversed(past_list[:-1]):
            if i == last_token:
                count += 1
            else:
                break
        if count >= max_repeat:
            logits[:, last_token] = -float("inf")

    banned = [token_id for token_id in _banned_ngram_tokens(past_list, no_repeat_ngram_size) if 0 <= token_id < vocab_size]
    if banned:
        logits[:, banned] = -float("inf")

    if top_k is not None and top_k > 0:
        tk = min(max(1, int(top_k)), vocab_size)
        topk_vals, _ = torch.topk(logits, tk, dim=-1)
        min_topk = topk_vals[:, -1].unsqueeze(-1)
        logits[logits < min_topk] = -float("inf")

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        for b in range(batch_size):
            sorted_mask = cumulative_probs[b] > top_p
            if sorted_mask.numel() > 0:
                sorted_mask[0] = False  # always keep the single most likely token
            tokens_to_remove = sorted_indices[b][sorted_mask]
            logits[b, tokens_to_remove] = -float("inf")

    # If filtering removed every candidate, fall back to the unfiltered logits.
    for b in range(batch_size):
        if torch.isneginf(logits[b]).all():
            logits[b] = orig_logits[b]

    probs = F.softmax(logits, dim=-1)
    if torch.isnan(probs).any():
        probs = torch.ones_like(logits) / logits.size(1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token.to(device)


@spaces.GPU
def predict(message, history):
    """Stream the assistant reply to ``message`` given ``(user, assistant)`` history pairs.

    Yields the cumulative reply text after each generated token.
    """
    chat_history = []
    for human, assistant in history:
        chat_history.append(f"{SPECIAL_TOKENS['user']} {human} {SPECIAL_TOKENS['eot']}")
        if assistant:
            chat_history.append(f"{SPECIAL_TOKENS['assistant']} {assistant} {SPECIAL_TOKENS['eot']}")
    chat_history.append(f"{SPECIAL_TOKENS['user']} {message} {SPECIAL_TOKENS['eot']}")
    prompt = f"{SPECIAL_TOKENS['system']} {SYSTEM_PROMPT} {SPECIAL_TOKENS['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS['assistant']}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Stop on the chat end-of-turn marker, and also on the tokenizer's EOS
    # (which doubles as the pad token) since nothing meaningful follows it.
    stop_ids = {EOT_ID, tokenizer.eos_token_id}
    generated_ids = []
    for _ in range(256):
        with torch.no_grad(), model_lock:
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[0]
        next_token = sample_next_token(logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1)
        token_id = int(next_token.squeeze().item())
        if token_id in stop_ids:
            break
        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.size(0), 1), device=device, dtype=attention_mask.dtype)], dim=1
        )
        generated_ids.append(token_id)
        # Decode the whole continuation each step: decoding tokens one at a time
        # loses word-boundary spaces and splits multi-byte characters.
        yield tokenizer.decode(generated_ids, skip_special_tokens=True)


def log_feedback(data: gr.LikeData, history: list):
    """Append a like/dislike record to the feedback dataset on the Hub.

    Every ``FINETUNE_TRIGGER_LIKES`` likes (counted since process start) a
    background fine-tuning thread is started.
    """
    global LIKE_COUNTER
    if not HF_TOKEN:
        print("Feedback not logged. HF_TOKEN not set.")
        return

    feedback_entry = {
        "prompt": history[data.index[0]][0],
        "response": data.value,
        "feedback": 1 if data.liked else 0,
        "timestamp": datetime.utcnow().isoformat(),
    }
    new_feedback_dataset = Dataset.from_dict({k: [v] for k, v in feedback_entry.items()})

    try:
        existing_dataset = load_dataset(FEEDBACK_DATASET_REPO, split="train", cache_dir="./cache")
        combined_dataset = concatenate_datasets([existing_dataset, new_feedback_dataset])
    except Exception as e:
        print(f"Could not load existing dataset: {e}. Creating a new one.")
        combined_dataset = new_feedback_dataset

    try:
        combined_dataset.push_to_hub(FEEDBACK_DATASET_REPO, private=False)
        feedback_icon = '👍' if data.liked else '👎'
        print(f"Successfully logged {feedback_icon} feedback. Dataset now has {len(combined_dataset)} entries.")

        if data.liked:
            with like_counter_lock:
                LIKE_COUNTER += 1
                current_likes = LIKE_COUNTER
            print(f"Like recorded. Total likes since start: {current_likes}.")
            if current_likes > 0 and current_likes % FINETUNE_TRIGGER_LIKES == 0:
                print(f"--- Like threshold of {FINETUNE_TRIGGER_LIKES} reached! Triggering fine-tuning. ---")
                tuning_thread = threading.Thread(target=run_tuning_task, daemon=True)
                tuning_thread.start()
    except Exception as e:
        print(f"Error logging feedback to Hub: {e}")


# -------------------------------
# 6) Background Fine-Tuning Logic (PyTorch Loop)
# -------------------------------
# Upper bound on the number of liked examples sampled for one fine-tuning run.
MAX_TRAINING_SAMPLES = 5900


@spaces.GPU(duration=120)
def run_tuning_task():
    """LoRA-tune a fresh copy of the base model on liked feedback, hot-swap it, and upload it.

    Only one run may be active at a time (``training_lock``); progress is
    reported through ``TRAINING_STATUS``. Errors are logged, not raised.
    """
    global TRAINING_STATUS
    if not training_lock.acquire(blocking=False):
        print("Tuning is already in progress. Skipping this trigger.")
        return

    print("\n--- Starting PyTorch Fine-Tuning Task ---")
    try:
        TRAINING_STATUS = "🔧 Preparing to improve Sam-2.5..."
        if not HF_TOKEN:
            TRAINING_STATUS = "Error: HF_TOKEN not set. Cannot run tuning."
            time.sleep(10)
            return

        feedback_data = load_dataset(FEEDBACK_DATASET_REPO, split="train", cache_dir="./cache")
        liked_data = feedback_data.filter(lambda x: x['feedback'] == 1)
        print(f"Found {len(liked_data)} total liked responses for training.")

        if len(liked_data) < MIN_LIKES_FOR_TRAINING:
            TRAINING_STATUS = "✅ Improvement complete! (Not enough new data to train, will try again later)."
            time.sleep(5)
            return

        # Shuffle and keep at most MAX_TRAINING_SAMPLES rows; Dataset.select raises
        # if asked for more rows than exist, so cap at the dataset size.
        num_samples = min(MAX_TRAINING_SAMPLES, len(liked_data))
        liked_data = liked_data.shuffle(seed=42).select(range(num_samples))

        def format_for_training(example):
            return {
                "text": f"{SPECIAL_TOKENS['system']} {SYSTEM_PROMPT} {SPECIAL_TOKENS['eot']}\n{SPECIAL_TOKENS['user']} {example['prompt']} {SPECIAL_TOKENS['eot']}\n{SPECIAL_TOKENS['assistant']} {example['response']} {SPECIAL_TOKENS['eot']}"
            }

        train_dataset = liked_data.map(format_for_training)

        print("Loading base model for tuning...")
        model_to_tune = Sam2(cfg)
        model_to_tune.load_state_dict(load_file(weights_filename))

        # PEFT cannot infer target modules for this custom architecture, so the
        # MHA query/value projections are named explicitly.
        peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj"],
        )
        peft_model = get_peft_model(model_to_tune, peft_config)
        peft_model.to(device)
        peft_model.print_trainable_parameters()

        tokenized_dataset = train_dataset.map(
            lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512),
            batched=True,
        )
        tokenized_dataset = tokenized_dataset.remove_columns(["text"])
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        train_dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE)

        # Only the LoRA adapters are trainable; frozen weights need no optimizer state.
        trainable_params = [p for p in peft_model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)

        TRAINING_STATUS = f"🔧 Sam-2.5 is starting training on {len(liked_data)} examples... Thank you all for your contribution to the dataset. The model will train and hot swap shortly.(This can be slow on CPU)"
        print("Starting model tuning on CPU...")

        peft_model.train()
        for epoch in range(NUM_EPOCHS):
            for i, batch in enumerate(train_dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                # Padding (pad == EOS) must not be learned as a target: -100 is
                # CrossEntropyLoss's default ignore_index.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs[0]
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                current_loss = loss.item()
                print(f"Epoch {epoch+1}, Batch {i+1}/{len(train_dataloader)}, Loss: {current_loss:.4f}")
                TRAINING_STATUS = f"🔧 You are witnessing the training of sam2.5. Training... Batch {i+1}/{len(train_dataloader)}, Loss: {current_loss:.4f}"

        print("Tuning complete.")
        TRAINING_STATUS = "✨ Finishing up... Merging improvements."
        merged_model = peft_model.merge_and_unload()

        # Copy weights into the live model in place (same object reference) while
        # inference is blocked by model_lock.
        with model_lock:
            print("Hot-swapping live model...")
            model.load_state_dict(merged_model.state_dict())
            model.to(device).eval()

        date_str = datetime.now().strftime("%Y%m%d-%H%M")
        new_repo_id = f"{TUNED_MODEL_REPO_OWNER}/Sam-2.5-PUBLIC-RLHF-{date_str}"
        print(f"Saving and uploading tuned model to {new_repo_id}...")

        import tempfile

        # TemporaryDirectory removes the export even if the upload fails.
        with tempfile.TemporaryDirectory() as tmp_root:
            local_dir = os.path.join(tmp_root, new_repo_id.split('/')[-1])
            os.makedirs(local_dir, exist_ok=True)
            merged_model.save_pretrained(local_dir, safe_serialization=False)
            tokenizer.save_pretrained(local_dir)

            api = HfApi()
            api.create_repo(repo_id=new_repo_id, repo_type="model", exist_ok=True)
            api.upload_folder(folder_path=local_dir, repo_id=new_repo_id, repo_type="model")

        print("Upload and hot-swap complete!")
        TRAINING_STATUS = "✅ Sam-2.5 has been successfully upgraded! Thank you. You have helped shaped the newest generation of sam 2.5 pro solver. You, helped make AI"
        time.sleep(5)
    except Exception as e:
        print(f"An error occurred during the tuning process: {e}")
        traceback.print_exc()
        TRAINING_STATUS = f"An error occurred during training: {e}"
        time.sleep(10)
    finally:
        TRAINING_STATUS = ""
        training_lock.release()
        print("--- PyTorch Fine-Tuning Task Finished ---")


# -------------------------------
# 7) UI Functions & Gradio Interface
# -------------------------------
def check_training_status():
    """Return a Gradio update showing TRAINING_STATUS, or hiding the banner when empty."""
    if TRAINING_STATUS:
        return gr.update(value=TRAINING_STATUS, visible=True)
    return gr.update(value="", visible=False)


def poll_status_updater():
    """Yield the training-status banner update once per second, indefinitely."""
    # NOTE(review): this generator never finishes, so each page load keeps one
    # event running for the lifetime of the connection — confirm the Gradio
    # version's queue/concurrency settings allow several simultaneous visitors.
    while True:
        yield check_training_status()
        time.sleep(1)


with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown(
        "# Sam-2.5-PRO-SOLVER-V2 Chat\n"
        "A self-improving chatbot powered by Sam-2. Use the thumb icons to rate responses! "
        "The model automatically fine-tunes on your positive feedback and gets smarter live."
    )

    training_status_md = gr.Markdown(value="", visible=False)

    chatbot = gr.Chatbot(label="Sam-2", bubble_full_width=False)
    chat_interface = gr.ChatInterface(fn=predict, chatbot=chatbot)
    chatbot.like(log_feedback, inputs=[chatbot], outputs=None)

    demo.load(poll_status_updater, None, training_status_md)

if __name__ == "__main__":
    print("Starting Gradio app. Tuning will be triggered by user feedback.")
    demo.launch(show_api=True)