# sonar-llm-1.3b/sonarllm_model/sonar_llm_model.py
from dataclasses import dataclass
from typing import List, Optional, Tuple

import nltk
from nltk.tokenize import sent_tokenize

# Punkt sentence-tokenizer data; newer NLTK releases also look for "punkt_tab".
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
import torch
import torch.nn as nn
import torch.nn.functional as F
from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
class Projector(nn.Module):
    """Single linear map between the SONAR embedding space and the LLaMA hidden space."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
@dataclass
class SONARLLMGenerationConfig:
# Outer sentence-level beam
sentence_beam_size: int = 4
latent_samples_per_step: int = 4 # M latent variants per active beam state
# Token-level decoder params
decoder_beam_size: int = 5 # default in fairseq2
decoder_temperature: float = 1.0 # default in fairseq2
normalize_sentence_scores: bool = True # False → sum of token log-probs
decoder_max_len: int = 256
# Latent sampling
temperature: float = 0.4
latent_top_p: Optional[float] = None # 0<p<=1 or None for Gaussian
temperature_mode: str = "relative" # "absolute" | "relative"
# Repetition control in latent space
repetition_penalty: float = 0.0
repetition_memory: int = 0
# Termination
max_sentences: int = 32
eos_threshold: float = 0.98
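
# A hedged example configuration: the keyword names are the dataclass fields above,
# but the values are illustrative, not tuned defaults.
#   cfg = SONARLLMGenerationConfig(
#       sentence_beam_size=8, latent_samples_per_step=8,
#       temperature=0.2, latent_top_p=0.9,
#       repetition_penalty=0.3, repetition_memory=4,
#   )
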
class SONARLLMGenerator(nn.Module):
    """Sentence-level beam search over latent reversed embeddings using the SONAR decoder.

    At each step:
      - Run LLaMA on the sentence-embedding history to get the final hidden state.
      - Sample multiple latent directions (temperature/latent_top_p, with repetition penalty).
      - Project to `reversed_emb` and decode text via the SONAR decoder.
      - Score each candidate by its decoder sentence log-prob (plus optional shaping).
      - Keep the top `sentence_beam_size` states and continue until EOS or `max_sentences`.

    This class does NOT modify existing project files and can be used standalone.
    """
def __init__(
self,
llama_model: nn.Module,
forward_proj: nn.Module,
reverse_proj: nn.Module,
sonar_decoder: EmbeddingToTextModelPipeline,
t2vec_model: TextToEmbeddingModelPipeline,
device: torch.device,
add_begin: bool = False,
) -> None:
super().__init__()
self.llama_model = llama_model.eval()
self.forward_proj = forward_proj.eval()
self.reverse_proj = reverse_proj.eval()
self.sonar_decoder = sonar_decoder.eval()
self.t2vec = t2vec_model.eval()
self.device = device
self.add_begin = add_begin
@torch.no_grad()
def generate(self, prefix_text: str, eos_emb: torch.Tensor, cfg: Optional[SONARLLMGenerationConfig] = None) -> str:
# Normalize and attach config to the instance for helper use
if cfg is None:
cfg = SONARLLMGenerationConfig()
self._cfg = cfg
sents = sent_tokenize(prefix_text)
if self.add_begin:
sents = ["Begin of text."] + sents
if len(sents) == 0:
sents = [prefix_text.strip()]
# Initialize prefix embeddings
emb_seq = self.t2vec.predict(sents, source_lang="eng_Latn").to(self.device)
# Beam state tuple: (sentences, embeddings_seq, cumulative_score, recent_dirs)
beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = [
(sents[:], emb_seq, 0.0, [])
]
steps = 0
while steps < self._cfg.max_sentences:
steps += 1
candidates: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
for (hist_sents, hist_emb, score, recent_dirs) in beams:
candidates.extend(
self._expand_beam_state(hist_sents, hist_emb, score, recent_dirs, eos_emb)
)
# Keep top-k beams
if len(candidates) == 0:
break
candidates.sort(key=lambda b: b[2], reverse=True)
beams = candidates[: int(self._cfg.sentence_beam_size)]
# If all beams look ended by EOS threshold, stop early
if self._all_close_to_eos(beams, eos_emb):
break
        best = max(beams, key=lambda b: b[2])
        result = self._join_sentences(best[0])
        if self.add_begin:
            # Drop the synthetic "Begin of text." prefix and the joining space after it.
            result = result[len("Begin of text."):].lstrip()
        return result
# --- internals ---
@torch.no_grad()
def _forward_hidden(self, emb_seq: torch.Tensor) -> torch.Tensor:
        # Accept (seq, dim) or (batch, seq, dim) and project into the LLaMA input space.
        proj = self.forward_proj(emb_seq.unsqueeze(0)) if emb_seq.ndim == 2 else self.forward_proj(emb_seq)
        out = self.llama_model(inputs_embeds=proj, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        # Final hidden state at the last position: the model's prediction for the next latent.
        return hidden[0, -1, :]
def _join_sentences(self, sents: List[str]) -> str:
return " ".join(sents)
def _update_recent_dirs(
self, recent: List[torch.Tensor], u: torch.Tensor, memory_cap: int
) -> List[torch.Tensor]:
if memory_cap <= 0:
return recent
if not torch.isfinite(u).all():
return recent
new_recent = recent + [u.detach().to("cpu")]
if len(new_recent) > int(memory_cap):
new_recent = new_recent[-int(memory_cap) :]
return new_recent
def _sample_noise_direction(
self, final_hidden: torch.Tensor, recent_dirs: List[torch.Tensor]
) -> torch.Tensor:
g = torch.randn_like(final_hidden)
        # The penalty is documented as clamped to [0, 1] with 0 meaning "no shift"
        # (see _apply_repetition_penalty_to_direction), so apply it only when it is
        # strictly positive and the direction memory is active.
        if (
            self._cfg.repetition_penalty is not None
            and float(self._cfg.repetition_penalty) > 0.0
            and self._cfg.repetition_memory > 0
            and len(recent_dirs) > 0
        ):
            g = self._apply_repetition_penalty_to_direction(
                g, float(self._cfg.repetition_penalty), int(self._cfg.repetition_memory), recent_dirs
            )
return g / (g.norm(p=2) + 1e-12)
def _sample_noise(
self, final_hidden: torch.Tensor, dir_unit: torch.Tensor
) -> torch.Tensor:
t = float(self._cfg.temperature)
if t <= 0.0:
return torch.zeros_like(final_hidden)
if self._cfg.temperature_mode not in ("absolute", "relative"):
raise ValueError(f"Unsupported temperature_mode: {self._cfg.temperature_mode}")
if self._cfg.temperature_mode == "absolute":
sigma = torch.tensor(t, device=final_hidden.device, dtype=final_hidden.dtype)
        else:
            # Relative mode: sigma scales with the RMS magnitude of the hidden state,
            # so e.g. temperature=0.4 perturbs by roughly 40% of its typical size.
            rms = torch.sqrt(torch.mean(final_hidden.to(torch.float32) ** 2))
            rms = torch.clamp(rms, min=1e-12).to(dtype=final_hidden.dtype, device=final_hidden.device)
            sigma = rms * t
top_p = self._cfg.latent_top_p
if top_p is None:
top_p = 1.0
return self._sample_truncated_normal_like(final_hidden, float(top_p), sigma, dir_unit)
def _sample_truncated_normal_like(
self, base_vector: torch.Tensor, top_p: float, sigma: torch.Tensor, dir_unit: torch.Tensor
) -> torch.Tensor:
        # Sample a radius from the chi distribution restricted to the lowest `top_p` of its
        # mass, using the Wilson-Hilferty approximation to chi-square quantiles:
        # chi2_k(p) ~= k * (1 - 2/(9k) + z_p * sqrt(2/(9k)))^3, with z_p the normal quantile.
dim = base_vector.numel()
device = base_vector.device
u = torch.rand((), device=device, dtype=torch.float32)
p = torch.clamp(u * float(top_p), min=1e-12, max=1.0 - 1e-12)
k = torch.tensor(float(dim), device=device, dtype=torch.float32)
z = torch.sqrt(torch.tensor(2.0, device=device, dtype=torch.float32)) * torch.special.erfinv(2.0 * p - 1.0)
term = 1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k))
term = torch.clamp(term, min=1e-12)
s = k * (term ** 3)
r = torch.sqrt(torch.clamp(s, min=1e-12)).to(dtype=base_vector.dtype)
return dir_unit * (r * sigma)
def _expand_beam_state(
self,
hist_sents: List[str],
hist_emb: torch.Tensor,
score: float,
recent_dirs: List[torch.Tensor],
eos_emb: torch.Tensor,
) -> List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]]:
"""Expand one beam state into candidate next states.
Returns a list of (new_hist_sents, new_hist_emb, new_score, new_recent_dirs).
"""
final_hidden = self._forward_hidden(hist_emb)
out: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
for _ in range(max(1, int(self._cfg.latent_samples_per_step))):
dir_unit = self._sample_noise_direction(final_hidden, recent_dirs)
noise = self._sample_noise(final_hidden, dir_unit)
h_perturbed = final_hidden + noise
z = self.reverse_proj(h_perturbed.unsqueeze(0))
texts, scores = self.sonar_decoder.predict(
z,
target_lang="eng_Latn",
beam_size=int(self._cfg.decoder_beam_size),
normalize_scores=bool(self._cfg.normalize_sentence_scores),
max_seq_len=self._cfg.decoder_max_len,
temperature=float(self._cfg.decoder_temperature),
return_scores=True,
)
text = texts[0]
sent_logprob = float(scores[0])
            # Re-encode the decoded text so the embedding history stays on the SONAR encoder manifold.
            z_re = self.t2vec.predict([text], source_lang="eng_Latn").to(self.device)
cand_score = score + sent_logprob
new_recent = self._update_recent_dirs(recent_dirs, dir_unit, self._cfg.repetition_memory)
new_hist_sents = hist_sents + [text]
new_hist_emb = torch.cat([hist_emb, z_re], dim=0)
out.append((new_hist_sents, new_hist_emb, cand_score, new_recent))
return out
def _apply_repetition_penalty_to_direction(
self, g: torch.Tensor, penalty: float, memory_cap: int, recent_dirs: List[torch.Tensor]
) -> torch.Tensor:
"""Mean-shift (A+) repetition penalty in latent direction space.
- penalty is clamped to [0, 1].
- penalty = 0 → no shift (q = 0.5).
- penalty = 1 → maximum shift (q ≈ q_min).
Mapping: q = 0.5^(1-penalty) * q_min^(penalty), beta = Phi^{-1}(1 - q),
and we set g' = g - beta * b_unit, where b_unit is the normalized average of recent directions.
"""
if memory_cap <= 0 or len(recent_dirs) == 0:
return g
# Aggregate and normalize recent directions
B = torch.stack(
[u.to(device=g.device, dtype=g.dtype) for u in recent_dirs[-int(memory_cap):]], dim=0
)
b = B.mean(dim=0)
bn = b.norm(p=2)
if not torch.isfinite(bn) or bn <= 1e-12:
return g
b_unit = b / bn
# Clamp and map penalty → beta via q
rp = float(penalty)
if rp < 0.0:
rp = 0.0
if rp > 1.0:
rp = 1.0
q_min = 1e-12
log_q = (1.0 - rp) * torch.log(torch.tensor(0.5, device=g.device, dtype=torch.float32))
log_q = log_q + rp * torch.log(torch.tensor(q_min, device=g.device, dtype=torch.float32))
q = torch.exp(log_q)
p = torch.clamp(1.0 - q, 1e-12, 1.0 - 1e-12)
beta = torch.sqrt(torch.tensor(2.0, device=g.device, dtype=g.dtype)) * torch.special.erfinv(2.0 * p - 1.0)
beta = torch.clamp(beta, 0.0, 7.5)
return g - (beta * b_unit)
    def _all_close_to_eos(
        self,
        beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]],
        eos_emb: torch.Tensor,
    ) -> bool:
for (_, emb, _, _) in beams:
last = emb[-1:, :]
sim = F.cosine_similarity(last, eos_emb, dim=1).item()
if sim < float(self._cfg.eos_threshold):
return False
return True
# --- factory ---
@classmethod
def load_from_checkpoint(
cls,
checkpoint_dir: str,
device: Optional[torch.device] = None,
generation_config: Optional[SONARLLMGenerationConfig] = None,
) -> "SONARLLMGenerator":
"""Load generator from a folder with config.json and weights.
The folder is expected to contain:
- config.json (with keys: pretrained_model_name_or_path, llama_config?, embed_dim)
- pytorch_model.bin (or model_state_dict inside the saved file)
"""
        import json
        import os

        from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg_path = os.path.join(checkpoint_dir, "config.json")
with open(cfg_path, "r", encoding="utf-8") as f:
cfg = json.load(f)
tokenizer = AutoTokenizer.from_pretrained(cfg["pretrained_model_name_or_path"])
tokenizer.pad_token = tokenizer.eos_token
llama_cfg_dict = cfg.get("llama_config", {})
if "vocab_size" not in llama_cfg_dict:
llama_cfg_dict["vocab_size"] = len(tokenizer)
# llama_cfg_dict["pad_token_id"] = tokenizer.pad_token_id
# llama_cfg_dict["bos_token_id"] = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 128000
# llama_cfg_dict["eos_token_id"] = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 128001
        # vocab_size was injected above, so always build the config from the dict.
        llama_cfg = LlamaConfig(**llama_cfg_dict)
llama_model = LlamaForCausalLM(llama_cfg).to(device).eval()
hidden_size = llama_cfg.hidden_size
embed_dim = cfg.get("embed_dim", 1024)
t2vec_model = TextToEmbeddingModelPipeline(
encoder="text_sonar_basic_encoder",
tokenizer="text_sonar_basic_encoder",
device=device,
).eval()
vec2text_model = EmbeddingToTextModelPipeline(
decoder="text_sonar_basic_decoder",
tokenizer="text_sonar_basic_encoder",
device=device,
).eval()
forward_projector = Projector(embed_dim, hidden_size).to(device).eval()
reverse_projector = Projector(hidden_size, embed_dim).to(device).eval()
gen = cls(
llama_model,
forward_projector,
reverse_projector,
vec2text_model,
t2vec_model,
device,
add_begin=cfg.get("add_begin", False),
)
# Load weights into generator to cover llama + projectors
ckpt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
state = torch.load(ckpt_bin, map_location=device, weights_only=True)
state = state.get("model_state_dict", state)
        # Unwrap a possible DataParallel-style wrapper before loading weights.
        raw = gen.module if hasattr(gen, "module") else gen
raw.load_state_dict(state, strict=False)
return gen
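

# A minimal, hedged smoke test. Assumptions: a local directory "checkpoint/" containing
# config.json and pytorch_model.bin, and "End of text." as the EOS sentence; both are
# illustrative, not fixed by this module.
if __name__ == "__main__":
    generator = SONARLLMGenerator.load_from_checkpoint("checkpoint")
    eos_embedding = generator.t2vec.predict(["End of text."], source_lang="eng_Latn").to(generator.device)
    gen_cfg = SONARLLMGenerationConfig(sentence_beam_size=4, latent_samples_per_step=4)
    print(generator.generate("Once upon a time, there was a fox.", eos_embedding, gen_cfg))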