|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import time |
|
|
import math |
|
|
import json |
|
|
import requests |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from pathlib import Path |
|
|
from safetensors.torch import load_file |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel |
|
|
import gradio as gr |
|
|
import os |
|
|
from datetime import datetime |
|
|
import threading |
|
|
import time |
|
|
import traceback |
|
|
import spaces |
|
|
|
|
|
from huggingface_hub import HfApi, login |
|
|
from datasets import Dataset, load_dataset, concatenate_datasets |
|
|
from peft import LoraConfig, get_peft_model |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Hugging Face Hub locations -------------------------------------------
# Dataset that accumulates like/dislike feedback rows.
FEEDBACK_DATASET_REPO = "Smilyai-labs/Open-Sam-2.5-chat"
# Namespace under which each fine-tuned snapshot is published.
TUNED_MODEL_REPO_OWNER = "Smilyai-labs"
# Source of the base tokenizer, config and weights.
BASE_MODEL_REPO = "Smilyai-labs/Sam-2.5-PRO-SOLVER-V2"

# --- Feedback thresholds --------------------------------------------------
# Every N-th like (counted since process start) launches a tuning run.
FINETUNE_TRIGGER_LIKES = 8
# A tuning run is skipped when fewer liked examples than this exist.
MIN_LIKES_FOR_TRAINING = 2

# --- LoRA training hyper-parameters ---------------------------------------
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
BATCH_SIZE = 1

# --- Authentication -------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
    print("Hugging Face token found. Feedback logging and model tuning are enabled.")
else:
    print("WARNING: Hugging Face token not found. Feedback will not be saved and tuning will not run.")

# --- Shared mutable state (touched from Gradio handlers and the tuning thread) ---
LIKE_COUNTER = 0                      # likes logged since start; guarded by like_counter_lock
like_counter_lock = threading.Lock()
training_lock = threading.Lock()      # held for the whole duration of a tuning run
model_lock = threading.Lock()         # serialises forward passes against the hot-swap
TRAINING_STATUS = ""                  # banner text shown in the UI; "" hides the banner
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Sam2Config(PretrainedConfig):
    """Hyper-parameters for the Sam-2 decoder-only transformer.

    Architecture fields (``vocab_size``, ``d_model``, ``n_layers``, ``n_heads``,
    ``ff_mult``, ``dropout``) are consumed by ``Sam2`` and its sub-modules.
    ``input_modality``, ``head_type`` and ``version`` are descriptive metadata
    that the model code in this file does not read. Any remaining keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "sam2"

    def __init__(self, vocab_size=32000, d_model=384, n_layers=6, n_heads=6, ff_mult=4.0,
                 dropout=0.1, input_modality="text", head_type="causal_lm", version="0.1",
                 **kwargs):
        # Architecture (assigned before the base initialiser, as before).
        self.vocab_size, self.d_model = vocab_size, d_model
        self.n_layers, self.n_heads = n_layers, n_heads
        self.ff_mult, self.dropout = ff_mult, dropout
        # Metadata fields.
        self.input_modality, self.head_type, self.version = input_modality, head_type, version
        super().__init__(**kwargs)
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Scales ``x`` by ``1 / sqrt(mean(x**2) + eps)`` and a learned per-feature
    ``weight`` (initialised to ones). There is no mean-centering and no bias.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        return self.weight * x * inv_rms
|
|
|
|
|
class MHA(nn.Module):
    """Causal multi-head self-attention with an optional key-padding mask.

    Args (constructor):
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape.

        ``attn_mask`` is indexed as (B, T): truthy entries are attendable keys and
        falsy entries (padding) are excluded. A query position that can see no key
        at all gets an all-zero attention row, so its output is zero instead of NaN.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # True marks a (query, key) pair that must not attend: future keys...
        blocked = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        if attn_mask is not None:
            # ...and padded keys, broadcast over heads and queries.
            blocked = blocked | ~attn_mask.unsqueeze(1).unsqueeze(2).bool()
        scores = scores.masked_fill(blocked, float("-inf"))

        attn = torch.softmax(scores, dim=-1)
        # A query with every key blocked (e.g. a left-padded position) has an
        # all -inf row, which softmax turns into NaN. Left in place, that NaN
        # would reach every position in the next layer through attn @ v
        # (0 * NaN = NaN), so zero those rows explicitly.
        attn = attn.masked_fill(blocked.all(dim=-1, keepdim=True), 0.0)

        out = torch.matmul(self.dropout(attn), v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)
|
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(dropout(silu(w1 x) * w2 x))``.

    ``w1``/``w2`` project ``d_model -> d_ff`` (gate and value paths), and ``w3``
    projects back to ``d_model``. No biases are used.
    """

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.w1(x))
        hidden = gate * self.w2(x)
        return self.w3(self.dropout(hidden))
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer layer: residual attention followed by residual SwiGLU.

    Both sub-layers read an RMS-normalised copy of the stream, and their outputs
    pass through ``self.drop`` before being added back.
    """

    def __init__(self, d_model, n_heads, ff_mult, dropout=0.0):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = MHA(d_model, n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.ff = SwiGLU(d_model, int(ff_mult * d_model), dropout=dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        stream = x + self.drop(attended)
        fed_forward = self.ff(self.norm2(stream))
        return stream + self.drop(fed_forward)
|
|
|
|
|
class Sam2(PreTrainedModel):
    """Sam-2 causal language model: embedding -> N ``Block`` layers -> RMSNorm -> LM head.

    The LM head shares its weight with the token embedding. ``forward`` returns
    plain tuples: ``(loss, logits)`` when ``labels`` are given, else ``(logits,)``.
    """

    config_class = Sam2Config

    def __init__(self, config: Sam2Config):
        super().__init__(config)
        self.config = config
        self.embed = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList(
            Block(config.d_model, config.n_heads, config.ff_mult, dropout=config.dropout)
            for _ in range(config.n_layers)
        )
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.embed.weight

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Only the token ids are forwarded; other generation kwargs are dropped.
        return {"input_ids": input_ids}

    def forward(self, input_ids=None, inputs_embeds=None, attention_mask=None, labels=None, **kwargs):
        """Run the model.

        Args:
            input_ids: token ids, used only when ``inputs_embeds`` is None.
            inputs_embeds: pre-computed embeddings; takes precedence over ``input_ids``.
            attention_mask: key-padding mask forwarded to every block.
            labels: next-token targets aligned with the input; shifted by one here.
                Cross-entropy uses its default ``ignore_index`` (-100).
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must provide either input_ids or inputs_embeds")
            inputs_embeds = self.embed(input_ids)

        hidden = inputs_embeds
        for layer in self.blocks:
            hidden = layer(hidden, attn_mask=attention_mask)
        logits = self.lm_head(self.norm(hidden))

        if labels is None:
            return (logits,)

        # Predict token t+1 from position t.
        vocab = self.config.vocab_size
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab)
        targets = labels[..., 1:].contiguous().view(-1).to(predictions.device)
        loss = nn.CrossEntropyLoss()(predictions, targets)
        return (loss, logits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
weights_filename = "model.safetensors"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_REPO)
tokenizer.pad_token = tokenizer.eos_token

# Raw-file downloads go through `requests`, which (unlike from_pretrained above)
# does not pick up the stored Hub token; send it explicitly when available so a
# private/gated base repo also works.
_hub_headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
_HTTP_TIMEOUT_S = 60

config_url = f"https://huggingface.co/{BASE_MODEL_REPO}/raw/main/config.json"
_config_response = requests.get(config_url, headers=_hub_headers, timeout=_HTTP_TIMEOUT_S)
# Fail loudly on HTTP errors. Otherwise an error payload would be fed to Sam2Config,
# which would silently build a default-sized model.
_config_response.raise_for_status()
config_data = _config_response.json()
cfg = Sam2Config(**config_data)

weights_url = f"https://huggingface.co/{BASE_MODEL_REPO}/resolve/main/{weights_filename}"
# Stream the checkpoint to disk in 1 MiB chunks instead of buffering it all in
# memory, and refuse to write an HTTP error page in place of the weights.
with requests.get(weights_url, headers=_hub_headers, timeout=_HTTP_TIMEOUT_S, stream=True) as _weights_response:
    _weights_response.raise_for_status()
    with open(weights_filename, "wb") as weights_file:
        for _chunk in _weights_response.iter_content(chunk_size=1 << 20):
            weights_file.write(_chunk)

model = Sam2(cfg)
state_dict = load_file(weights_filename)
model.load_state_dict(state_dict)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
print(f"Inference will run on: {device}")

SPECIAL_TOKENS = {"bos": "<|bos|>", "eot": "<|eot|>", "user": "<|user|>", "assistant": "<|assistant|>", "system": "<|system|>"}
SYSTEM_PROMPT = "You are Sam-2, a friendly and concise chatbot. Always give short, direct answers and avoid medical or legal advice."

# convert_tokens_to_ids yields the unk id (or None) for an unknown token. A bare
# `or` fallback would keep that unk id and would also discard a legitimate id 0,
# so check explicitly before falling back to EOS.
_eot_id = tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS["eot"])
EOT_ID = _eot_id if _eot_id is not None and _eot_id != tokenizer.unk_token_id else tokenizer.eos_token_id

AutoModelForCausalLM.register(Sam2Config, Sam2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU
def sample_next_token(logits, past_tokens, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.1, max_repeat=5, no_repeat_ngram_size=3):
    """Sample one next-token id per batch row.

    Args:
        logits: (batch, vocab) scores, or (batch, seq, vocab), in which case the
            last position is used. The caller's tensor is not modified.
        past_tokens: 1-D tensor or iterable of previously seen token ids. The same
            history is applied to every batch row.
        temperature: logits are divided by this before filtering (1.0 = no-op).
        top_k: keep only tokens scoring at least the k-th best score (ties kept);
            None or <= 0 disables.
        top_p: nucleus threshold in (0, 1). Keep the smallest highest-probability
            set whose mass reaches top_p. Other values disable it.
        repetition_penalty: CTRL-style penalty for every token in the history.
            Positive scores are divided by it and negative scores multiplied by it.
        max_repeat: ban the last token once it has occurred this many times in a row.
        no_repeat_ngram_size: ban the last token once the final
            ``no_repeat_ngram_size`` tokens are all that same token (0 disables).

    Returns:
        LongTensor of shape (batch, 1), moved to the module-level ``device``.
    """
    logits = (logits[:, -1, :] if logits.dim() == 3 else logits).clone()
    vocab_size = logits.size(-1)
    orig_logits = logits.clone()  # fallback if every token ends up filtered out
    neg_inf = float("-inf")

    if temperature != 1.0:
        logits = logits / float(temperature)

    past_list = past_tokens.tolist() if isinstance(past_tokens, torch.Tensor) else list(past_tokens)

    # Repetition penalty. Dividing a *negative* score would raise it and make the
    # repeated token more likely, so negative scores are multiplied instead.
    seen = sorted(t for t in set(past_list) if 0 <= t < vocab_size)
    if seen:
        seen_idx = torch.tensor(seen, dtype=torch.long, device=logits.device)
        seen_logits = logits[:, seen_idx]
        logits[:, seen_idx] = torch.where(
            seen_logits > 0, seen_logits / repetition_penalty, seen_logits * repetition_penalty
        )

    # Ban a token that has already run `max_repeat` times in a row. The
    # `past_list` guard avoids an IndexError on an empty history when max_repeat <= 0.
    if past_list and len(past_list) >= max_repeat:
        last_token, run = past_list[-1], 1
        for tok in reversed(past_list[:-1]):
            if tok != last_token:
                break
            run += 1
        if run >= max_repeat and 0 <= last_token < vocab_size:
            logits[:, last_token] = neg_inf

    # "No-repeat n-gram": a candidate is banned only if appending it reproduces
    # the trailing n-gram, i.e. the last n tokens are one repeated token and the
    # candidate is that token. This check costs O(n); the previous per-vocab scan
    # cost O(vocab * n) of Python work per generated token.
    # The previous `past_list[-(n - 1):]` slice took the WHOLE history when n == 1,
    # so it never banned anything; here n == 1 consistently bans an immediate repeat.
    # NOTE(review): this is deliberately narrower than Hugging Face-style
    # no_repeat_ngram blocking. past_tokens includes the prompt, so banning every
    # previously seen n-gram could, depending on tokenization, ban <|eot|> after
    # any reply ending the same way as an earlier prompt turn.
    n = no_repeat_ngram_size
    if n > 0 and len(past_list) >= n:
        tail = past_list[-n:]
        if all(tok == tail[-1] for tok in tail) and 0 <= tail[-1] < vocab_size:
            logits[:, tail[-1]] = neg_inf

    if top_k is not None and top_k > 0:
        k = min(max(1, int(top_k)), vocab_size)
        kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, neg_inf)

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Mass strictly before each token. A token is dropped only once the tokens
        # ahead of it already reach top_p, so the token that crosses the threshold
        # is kept (standard nucleus sampling). The best token always survives.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_remove = mass_before >= top_p
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, neg_inf)

    # If filtering removed every candidate in a row, fall back to its raw scores.
    dead_rows = torch.isneginf(logits).all(dim=-1)
    if dead_rows.any():
        logits[dead_rows] = orig_logits[dead_rows]

    probs = F.softmax(logits, dim=-1)
    if torch.isnan(probs).any():
        probs = torch.ones_like(logits) / logits.size(1)
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token.to(device)
|
|
@spaces.GPU
def predict(message, history):
    """Stream Sam-2's reply to ``message``, given the prior chat ``history``.

    ``history`` is iterated as ``(user, assistant)`` pairs. This presumably matches
    Gradio's tuple-style chat history (TODO confirm if the chatbot switches to the
    "messages" format). The prompt is the system prompt, then each turn, then an
    open assistant tag. Up to 256 tokens are sampled; generation stops early at
    ``EOT_ID``. Each yield is the full reply text decoded so far.
    """
    chat_history = []
    for human, assistant in history:
        chat_history.append(f"{SPECIAL_TOKENS['user']} {human} {SPECIAL_TOKENS['eot']}")
        if assistant:
            chat_history.append(f"{SPECIAL_TOKENS['assistant']} {assistant} {SPECIAL_TOKENS['eot']}")
    chat_history.append(f"{SPECIAL_TOKENS['user']} {message} {SPECIAL_TOKENS['eot']}")

    prompt = f"{SPECIAL_TOKENS['system']} {SYSTEM_PROMPT} {SPECIAL_TOKENS['eot']}\n" + "\n".join(chat_history) + f"\n{SPECIAL_TOKENS['assistant']}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    generated_ids = []
    for _ in range(256):
        # model_lock keeps the forward pass from racing the tuning hot-swap.
        with torch.no_grad(), model_lock:
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[0]
        next_token = sample_next_token(logits, input_ids[0], temperature=0.4, top_k=50, top_p=0.9, repetition_penalty=1.1)
        token_id = int(next_token.squeeze().item())
        if token_id == EOT_ID:
            break

        generated_ids.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones((attention_mask.size(0), 1), device=device, dtype=attention_mask.dtype)], dim=1)

        # Decode the whole reply each step rather than concatenating per-token
        # decodes. Decoding tokens one by one can split multi-byte characters and
        # lose word-boundary spaces, depending on the tokenizer type.
        yield tokenizer.decode(generated_ids, skip_special_tokens=True)
|
|
|
|
|
def log_feedback(data: gr.LikeData, history: list):
    """Append one like/dislike row to the Hub feedback dataset and maybe start tuning.

    Args:
        data: the Gradio like event. ``data.index[0]`` selects the chat turn,
            ``data.value`` is logged as the response, and ``data.liked`` as the rating.
        history: the chatbot value; ``history[i][0]`` is read as the prompt of turn i.

    Every ``FINETUNE_TRIGGER_LIKES``-th successfully logged like (counted since
    process start) launches ``run_tuning_task`` in a daemon thread. Hub errors are
    printed rather than raised.
    """
    global LIKE_COUNTER

    if not HF_TOKEN:
        print("Feedback not logged. HF_TOKEN not set.")
        return

    feedback_entry = {
        "prompt": history[data.index[0]][0],
        "response": data.value,
        "feedback": 1 if data.liked else 0,
        "timestamp": datetime.utcnow().isoformat(),
    }
    new_feedback_dataset = Dataset.from_dict({k: [v] for k, v in feedback_entry.items()})

    try:
        existing_dataset = load_dataset(FEEDBACK_DATASET_REPO, split="train", cache_dir="./cache")
        combined_dataset = concatenate_datasets([existing_dataset, new_feedback_dataset])
    except FileNotFoundError as e:
        # Repo missing or empty. In recent `datasets` versions DatasetNotFoundError
        # and EmptyDatasetError subclass FileNotFoundError. Starting fresh is safe.
        print(f"Could not load existing dataset: {e}. Creating a new one.")
        combined_dataset = new_feedback_dataset
    except Exception as e:
        # Any other failure (network, auth, schema mismatch) must NOT fall through to
        # a push. push_to_hub rewrites the split's data files, so pushing only the
        # new row would wipe all previously collected feedback. Drop this entry.
        print(f"Could not load existing dataset: {e}. Feedback not logged to avoid overwriting existing entries.")
        return

    try:
        combined_dataset.push_to_hub(FEEDBACK_DATASET_REPO, private=False)
        feedback_icon = '👍' if data.liked else '👎'
        print(f"Successfully logged {feedback_icon} feedback. Dataset now has {len(combined_dataset)} entries.")

        if data.liked:
            with like_counter_lock:
                LIKE_COUNTER += 1
                current_likes = LIKE_COUNTER
            print(f"Like recorded. Total likes since start: {current_likes}.")

            if current_likes > 0 and current_likes % FINETUNE_TRIGGER_LIKES == 0:
                print(f"--- Like threshold of {FINETUNE_TRIGGER_LIKES} reached! Triggering fine-tuning. ---")
                tuning_thread = threading.Thread(target=run_tuning_task, daemon=True)
                tuning_thread.start()
    except Exception as e:
        print(f"Error logging feedback to Hub: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def run_tuning_task():
    """LoRA-tune the base weights on liked feedback, hot-swap the result and publish it.

    Only one run proceeds at a time: a non-blocking acquire of ``training_lock``
    makes overlapping triggers return immediately. Progress goes to the global
    ``TRAINING_STATUS`` string (polled by the UI), which is reset to "" and the
    lock released when the run ends, whether it succeeds, fails or skips.
    Exceptions are printed and shown in the status, not raised.
    """
    global model, TRAINING_STATUS

    if not training_lock.acquire(blocking=False):
        print("Tuning is already in progress. Skipping this trigger.")
        return

    print("\n--- Starting PyTorch Fine-Tuning Task ---")
    try:
        TRAINING_STATUS = "🔧 Preparing to improve Sam-2.5..."

        if not HF_TOKEN:
            TRAINING_STATUS = "Error: HF_TOKEN not set. Cannot run tuning."
            time.sleep(10)  # presumably keeps the message visible before `finally` clears it
            return

        feedback_data = load_dataset(FEEDBACK_DATASET_REPO, split="train", cache_dir="./cache")
        liked_data = feedback_data.filter(lambda x: x['feedback'] == 1)
        print(f"Found {len(liked_data)} total liked responses for training.")

        if len(liked_data) < MIN_LIKES_FOR_TRAINING:
            TRAINING_STATUS = "✅ Improvement complete! (Not enough new data to train, will try again later)."
            time.sleep(5)
            return

        # Cap training at 5900 shuffled examples. A bare select(range(5900)) raises
        # IndexError whenever fewer rows exist, which made every run fail until
        # 5900 likes had accumulated.
        max_train_examples = 5900
        liked_data = liked_data.shuffle(seed=42).select(range(min(max_train_examples, len(liked_data))))

        def format_for_training(example):
            return {"text": f"{SPECIAL_TOKENS['system']} {SYSTEM_PROMPT} {SPECIAL_TOKENS['eot']}\n{SPECIAL_TOKENS['user']} {example['prompt']} {SPECIAL_TOKENS['eot']}\n{SPECIAL_TOKENS['assistant']} {example['response']} {SPECIAL_TOKENS['eot']}"}

        train_dataset = liked_data.map(format_for_training)

        # Always start from the pristine base checkpoint on disk, not the live model.
        print("Loading base model for tuning...")
        model_to_tune = Sam2(cfg)
        model_to_tune.load_state_dict(load_file(weights_filename))

        peft_config = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj"],
        )
        peft_model = get_peft_model(model_to_tune, peft_config)
        peft_model.to(device)
        peft_model.print_trainable_parameters()

        tokenized_dataset = train_dataset.map(lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512), batched=True)
        tokenized_dataset = tokenized_dataset.remove_columns(["text"])
        tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        train_dataloader = DataLoader(tokenized_dataset, batch_size=BATCH_SIZE)

        # Hand the optimizer only the trainable (LoRA) parameters.
        optimizer = torch.optim.AdamW((p for p in peft_model.parameters() if p.requires_grad), lr=LEARNING_RATE)

        TRAINING_STATUS = f"🔧 Sam-2.5 is starting training on {len(liked_data)} examples... Thank you all for your contribution to the dataset. The model will train and hot swap shortly.(This can be slow on CPU)"
        print(f"Starting model tuning on {device}...")
        peft_model.train()
        num_batches = len(train_dataloader)

        for epoch in range(NUM_EPOCHS):
            for i, batch in enumerate(train_dataloader):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                # Exclude padding from the loss (-100 is CrossEntropyLoss's default
                # ignore_index). Otherwise every example teaches the model to emit
                # the pad token (== eos here) for the padded positions.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs[0]
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                current_loss = loss.item()
                print(f"Epoch {epoch+1}, Batch {i+1}/{num_batches}, Loss: {current_loss:.4f}")
                TRAINING_STATUS = f"🔧 You are witnessing the training of sam2.5. Training... Batch {i+1}/{num_batches}, Loss: {current_loss:.4f}"

        print("Tuning complete.")
        TRAINING_STATUS = "✨ Finishing up... Merging improvements."
        merged_model = peft_model.merge_and_unload()

        # Swap weights under model_lock so no forward pass sees a half-loaded model.
        with model_lock:
            print("Hot-swapping live model...")
            model.load_state_dict(merged_model.state_dict())
            model.to(device).eval()

        date_str = datetime.now().strftime("%Y%m%d-%H%M")
        new_repo_id = f"{TUNED_MODEL_REPO_OWNER}/Sam-2.5-PUBLIC-RLHF-{date_str}"
        print(f"Saving and uploading tuned model to {new_repo_id}...")

        local_dir = f"./{new_repo_id.split('/')[-1]}"
        os.makedirs(local_dir, exist_ok=True)
        import shutil
        try:
            merged_model.save_pretrained(local_dir, safe_serialization=False)
            tokenizer.save_pretrained(local_dir)
            api = HfApi()
            api.create_repo(repo_id=new_repo_id, repo_type="model", exist_ok=True)
            api.upload_folder(folder_path=local_dir, repo_id=new_repo_id, repo_type="model")
        finally:
            # Remove the export even when saving/uploading fails, so failed runs do
            # not leave checkpoints piling up on disk.
            shutil.rmtree(local_dir, ignore_errors=True)

        print("Upload and hot-swap complete!")
        TRAINING_STATUS = "✅ Sam-2.5 has been successfully upgraded! Thank you. You have helped shaped the newest generation of sam 2.5 pro solver. You, helped make AI"
        time.sleep(5)

    except Exception as e:
        print(f"An error occurred during the tuning process: {e}")
        traceback.print_exc()
        TRAINING_STATUS = f"An error occurred during training: {e}"
        time.sleep(10)

    finally:
        TRAINING_STATUS = ""
        training_lock.release()
        print("--- PyTorch Fine-Tuning Task Finished ---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_training_status():
    """Return a Gradio update for the status banner.

    The banner shows ``TRAINING_STATUS`` when it is non-empty and is hidden with
    an empty value otherwise.
    """
    status = TRAINING_STATUS
    if not status:
        return gr.update(value="", visible=False)
    return gr.update(value=status, visible=True)
|
|
|
|
|
|
|
|
def poll_status_updater():
    """Yield a fresh status-banner update roughly once per second, forever."""
    while True:
        banner_update = check_training_status()
        yield banner_update
        time.sleep(1)
|
|
|
|
|
# Gradio UI: a status banner for background fine-tuning, the chat interface
# driven by `predict`, and like/dislike handling via `log_feedback`.
# NOTE(review): components presumably render in creation order, so the statement
# order inside this `with` block is also the page layout — keep it stable.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="blue")) as demo:
    gr.Markdown("""


    # Sam-2.5-PRO-SOLVER-V2 Chat


    A self-improving chatbot powered by Sam-2. Use the thumb icons to rate responses!


    The model automatically fine-tunes on your positive feedback and gets smarter live.


    """)

    # Hidden until check_training_status sees a non-empty TRAINING_STATUS.
    training_status_md = gr.Markdown(value="", visible=False)
    chatbot = gr.Chatbot(label="Sam-2", bubble_full_width=False)
    chat_interface = gr.ChatInterface(fn=predict, chatbot=chatbot)
    # The chatbot value is passed as `history`; presumably Gradio injects the
    # gr.LikeData event via log_feedback's type hint — confirm on Gradio upgrades.
    chatbot.like(log_feedback, inputs=[chatbot], outputs=None)

    # On page load, start the generator that pushes a banner update about every 1 s.
    demo.load(poll_status_updater, None, training_status_md)

if __name__ == "__main__":
    print("Starting Gradio app. Tuning will be triggered by user feedback.")
    demo.launch(show_api=True)