import time

from dataclasses import dataclass
from typing import List, Optional, Tuple

import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt", quiet=True)

import torch
import torch.nn as nn
import torch.nn.functional as F

from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

class Projector(nn.Module):
    """Linear projection between the SONAR embedding space and the LLaMA hidden space."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

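# Example wiring of the two projectors (hypothetical dimensions: SONAR text
# embeddings are 1024-dimensional, while the LLaMA hidden size depends on the config):
#
#   forward_proj = Projector(1024, 4096)  # SONAR embedding -> LLaMA input
#   reverse_proj = Projector(4096, 1024)  # LLaMA hidden -> SONAR embedding
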
@dataclass
class SONARLLMGenerationConfig:
    # Sentence-level beam search.
    sentence_beam_size: int = 4
    latent_samples_per_step: int = 4

    # SONAR decoder settings.
    decoder_beam_size: int = 5
    decoder_temperature: float = 1.0
    normalize_sentence_scores: bool = True
    decoder_max_len: int = 256

    # Latent noise sampling.
    temperature: float = 0.4
    latent_top_p: Optional[float] = None
    temperature_mode: str = "relative"  # "absolute", or "relative" to the hidden-state RMS

    # Repetition penalty in latent direction space (0 disables it).
    repetition_penalty: float = 0.0
    repetition_memory: int = 0

    # Stopping criteria.
    max_sentences: int = 32
    eos_threshold: float = 0.98

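# A hypothetical configuration biased toward diversity: more beams and latent
# samples per step and a hotter latent temperature, at the cost of slower decoding:
#
#   cfg = SONARLLMGenerationConfig(
#       sentence_beam_size=8,
#       latent_samples_per_step=8,
#       temperature=0.6,
#       latent_top_p=0.9,
#   )
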
class SONARLLMGenerator(torch.nn.Module):
    """Sentence-level beam search over latent reversed embeddings using the SONAR decoder.

    At each step:
    - Run LLaMA on the sentence-embedding history to get the final hidden state.
    - Sample multiple latent directions (temperature/latent_top_p, with an optional
      repetition penalty).
    - Project each sample to a reversed embedding and decode it to text via the SONAR decoder.
    - Score each candidate using the decoder sentence logprob (plus optional shaping).
    - Keep the top `sentence_beam_size` states and continue until EOS or `max_sentences`.

    This class does NOT modify existing project files and can be used standalone.
    """

    def __init__(
        self,
        llama_model: nn.Module,
        forward_proj: nn.Module,
        reverse_proj: nn.Module,
        sonar_decoder: EmbeddingToTextModelPipeline,
        t2vec_model: TextToEmbeddingModelPipeline,
        device: torch.device,
        add_begin: bool = False,
    ) -> None:
        super().__init__()
        self.llama_model = llama_model.eval()
        self.forward_proj = forward_proj.eval()
        self.reverse_proj = reverse_proj.eval()
        self.sonar_decoder = sonar_decoder.eval()
        self.t2vec = t2vec_model.eval()
        self.device = device
        self.add_begin = add_begin

    @torch.no_grad()
    def generate(
        self,
        prefix_text: str,
        eos_emb: torch.Tensor,
        cfg: Optional[SONARLLMGenerationConfig] = None,
    ) -> str:
        if cfg is None:
            cfg = SONARLLMGenerationConfig()
        self._cfg = cfg

        sents = sent_tokenize(prefix_text)
        if self.add_begin:
            sents = ["Begin of text."] + sents
        if len(sents) == 0:
            sents = [prefix_text.strip()]

        # Encode the prefix sentences into SONAR embeddings.
        emb_seq = self.t2vec.predict(sents, source_lang="eng_Latn").to(self.device)

        # Each beam state: (sentence history, embedding history, cumulative score,
        # recent latent directions used by the repetition penalty).
        beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = [
            (sents[:], emb_seq, 0.0, [])
        ]

        steps = 0
        while steps < self._cfg.max_sentences:
            steps += 1
            candidates: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
            for (hist_sents, hist_emb, score, recent_dirs) in beams:
                candidates.extend(
                    self._expand_beam_state(hist_sents, hist_emb, score, recent_dirs, eos_emb)
                )

            if len(candidates) == 0:
                break
            candidates.sort(key=lambda b: b[2], reverse=True)
            beams = candidates[: int(self._cfg.sentence_beam_size)]

            if self._all_close_to_eos(beams, eos_emb):
                break

        best = max(beams, key=lambda b: b[2])
        result = self._join_sentences(best[0])
        if self.add_begin:
            # Strip the synthetic "Begin of text." sentence and the joining space.
            result = result[len("Begin of text."):].lstrip()
        return result

    @torch.no_grad()
    def _forward_hidden(self, emb_seq: torch.Tensor) -> torch.Tensor:
        # Project SONAR embeddings into the LLaMA input space (adding a batch
        # dimension if needed) and return the last position's final hidden state.
        proj = self.forward_proj(emb_seq.unsqueeze(0)) if emb_seq.ndim == 2 else self.forward_proj(emb_seq)
        out = self.llama_model(inputs_embeds=proj, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        return hidden[0, -1, :]

    def _join_sentences(self, sents: List[str]) -> str:
        return " ".join(sents)

    def _update_recent_dirs(
        self, recent: List[torch.Tensor], u: torch.Tensor, memory_cap: int
    ) -> List[torch.Tensor]:
        if memory_cap <= 0:
            return recent
        if not torch.isfinite(u).all():
            return recent
        # Keep at most `memory_cap` of the most recent unit directions (stored on CPU).
        new_recent = recent + [u.detach().to("cpu")]
        if len(new_recent) > int(memory_cap):
            new_recent = new_recent[-int(memory_cap):]
        return new_recent

    def _sample_noise_direction(
        self, final_hidden: torch.Tensor, recent_dirs: List[torch.Tensor]
    ) -> torch.Tensor:
        # Draw an isotropic Gaussian direction, optionally shifted away from
        # recently used directions, and normalize it to unit length. Note the
        # penalty convention: 0 disables it (see _apply_repetition_penalty_to_direction).
        g = torch.randn_like(final_hidden)
        if (
            self._cfg.repetition_penalty is not None
            and float(self._cfg.repetition_penalty) > 0.0
            and self._cfg.repetition_memory > 0
            and len(recent_dirs) > 0
        ):
            g = self._apply_repetition_penalty_to_direction(
                g, float(self._cfg.repetition_penalty), int(self._cfg.repetition_memory), recent_dirs
            )
        return g / (g.norm(p=2) + 1e-12)

    def _sample_noise(
        self, final_hidden: torch.Tensor, dir_unit: torch.Tensor
    ) -> torch.Tensor:
        t = float(self._cfg.temperature)
        if t <= 0.0:
            return torch.zeros_like(final_hidden)

        if self._cfg.temperature_mode not in ("absolute", "relative"):
            raise ValueError(f"Unsupported temperature_mode: {self._cfg.temperature_mode}")

        if self._cfg.temperature_mode == "absolute":
            sigma = torch.tensor(t, device=final_hidden.device, dtype=final_hidden.dtype)
        else:
            # "relative": scale the noise by the RMS of the hidden state.
            rms = torch.sqrt(torch.mean(final_hidden.to(torch.float32) ** 2))
            rms = torch.clamp(rms, min=1e-12).to(dtype=final_hidden.dtype, device=final_hidden.device)
            sigma = rms * t

        top_p = self._cfg.latent_top_p
        if top_p is None:
            top_p = 1.0
        return self._sample_truncated_normal_like(final_hidden, float(top_p), sigma, dir_unit)

    def _sample_truncated_normal_like(
        self, base_vector: torch.Tensor, top_p: float, sigma: torch.Tensor, dir_unit: torch.Tensor
    ) -> torch.Tensor:
        """Sample a noise vector whose radius comes from the lower `top_p` tail of a chi distribution.

        The squared radius of a `dim`-dimensional standard Gaussian is chi-squared
        with `dim` degrees of freedom; its quantile is approximated with the
        Wilson-Hilferty transform, so the result scales like `sigma * chi_dim`.
        """
        dim = base_vector.numel()
        device = base_vector.device
        # Uniform quantile restricted to the lowest `top_p` probability mass.
        u = torch.rand((), device=device, dtype=torch.float32)
        p = torch.clamp(u * float(top_p), min=1e-12, max=1.0 - 1e-12)
        k = torch.tensor(float(dim), device=device, dtype=torch.float32)
        # Standard normal quantile z = Phi^{-1}(p).
        z = torch.sqrt(torch.tensor(2.0, device=device, dtype=torch.float32)) * torch.special.erfinv(2.0 * p - 1.0)
        # Wilson-Hilferty approximation of the chi-squared quantile:
        # s ≈ k * (1 - 2/(9k) + z * sqrt(2/(9k)))^3.
        term = 1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k))
        term = torch.clamp(term, min=1e-12)
        s = k * (term ** 3)
        r = torch.sqrt(torch.clamp(s, min=1e-12)).to(dtype=base_vector.dtype)
        return dir_unit * (r * sigma)

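    # Sanity check for the radius approximation (hypothetical numbers): with
    # dim = 4096 and top_p = 1.0, the median draw has p = 0.5, hence z = 0,
    # so s = 4096 * (1 - 2 / (9 * 4096)) ** 3 ≈ 4095.3 and r ≈ 64.0, close to
    # sqrt(dim), the typical norm of a 4096-dimensional standard Gaussian.
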
    def _expand_beam_state(
        self,
        hist_sents: List[str],
        hist_emb: torch.Tensor,
        score: float,
        recent_dirs: List[torch.Tensor],
        eos_emb: torch.Tensor,
    ) -> List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]]:
        """Expand one beam state into candidate next states.

        Returns a list of (new_hist_sents, new_hist_emb, new_score, new_recent_dirs).
        `eos_emb` is accepted for interface symmetry with the caller and is
        currently unused here.
        """
        final_hidden = self._forward_hidden(hist_emb)
        out: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []

        for _ in range(max(1, int(self._cfg.latent_samples_per_step))):
            # Perturb the hidden state along a sampled latent direction.
            dir_unit = self._sample_noise_direction(final_hidden, recent_dirs)
            noise = self._sample_noise(final_hidden, dir_unit)
            h_perturbed = final_hidden + noise
            z = self.reverse_proj(h_perturbed.unsqueeze(0))

            # Decode the reversed embedding to text, keeping the sentence score.
            texts, scores = self.sonar_decoder.predict(
                z,
                target_lang="eng_Latn",
                beam_size=int(self._cfg.decoder_beam_size),
                normalize_scores=bool(self._cfg.normalize_sentence_scores),
                max_seq_len=self._cfg.decoder_max_len,
                temperature=float(self._cfg.decoder_temperature),
                return_scores=True,
            )
            text = texts[0]
            sent_logprob = float(scores[0])

            # Re-encode the decoded text so the history stays in SONAR space.
            z_re = self.t2vec.predict([text], source_lang="eng_Latn").to(self.device)

            cand_score = score + sent_logprob
            new_recent = self._update_recent_dirs(recent_dirs, dir_unit, self._cfg.repetition_memory)

            new_hist_sents = hist_sents + [text]
            new_hist_emb = torch.cat([hist_emb, z_re], dim=0)

            out.append((new_hist_sents, new_hist_emb, cand_score, new_recent))

        return out

    def _apply_repetition_penalty_to_direction(
        self, g: torch.Tensor, penalty: float, memory_cap: int, recent_dirs: List[torch.Tensor]
    ) -> torch.Tensor:
        """Mean-shift (A+) repetition penalty in latent direction space.

        - `penalty` is clamped to [0, 1].
        - penalty = 0 → no shift (q = 0.5).
        - penalty = 1 → maximum shift (q ≈ q_min).
        Mapping: q = 0.5^(1-penalty) * q_min^(penalty), beta = Phi^{-1}(1 - q),
        and g' = g - beta * b_unit, where b_unit is the normalized mean of the
        recent directions.
        """
        if memory_cap <= 0 or len(recent_dirs) == 0:
            return g

        # Mean of the most recent directions, normalized to a unit vector.
        B = torch.stack(
            [u.to(device=g.device, dtype=g.dtype) for u in recent_dirs[-int(memory_cap):]], dim=0
        )
        b = B.mean(dim=0)
        bn = b.norm(p=2)
        if not torch.isfinite(bn) or bn <= 1e-12:
            return g
        b_unit = b / bn

        # Interpolate q between 0.5 (no shift) and q_min (maximum shift) in log space.
        rp = min(max(float(penalty), 0.0), 1.0)
        q_min = 1e-12
        log_q = (1.0 - rp) * torch.log(torch.tensor(0.5, device=g.device, dtype=torch.float32))
        log_q = log_q + rp * torch.log(torch.tensor(q_min, device=g.device, dtype=torch.float32))
        q = torch.exp(log_q)
        p = torch.clamp(1.0 - q, 1e-12, 1.0 - 1e-12)
        # beta = Phi^{-1}(p), clamped to avoid extreme shifts.
        beta = torch.sqrt(torch.tensor(2.0, device=g.device, dtype=g.dtype)) * torch.special.erfinv(2.0 * p - 1.0)
        beta = torch.clamp(beta, 0.0, 7.5)
        return g - (beta * b_unit)

    def _all_close_to_eos(self, beams, eos_emb: torch.Tensor) -> bool:
        # Stop when every beam's last sentence embedding is within the cosine
        # threshold of the EOS embedding.
        for (_, emb, _, _) in beams:
            last = emb[-1:, :]
            sim = F.cosine_similarity(last, eos_emb, dim=1).item()
            if sim < float(self._cfg.eos_threshold):
                return False
        return True

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_dir: str,
        device: Optional[torch.device] = None,
        generation_config: Optional[SONARLLMGenerationConfig] = None,
    ) -> "SONARLLMGenerator":
        """Load a generator from a folder containing config.json and weights.

        The folder is expected to contain:
        - config.json (with keys: pretrained_model_name_or_path, llama_config?, embed_dim)
        - pytorch_model.bin (or a model_state_dict inside the saved file)
        """
        import json
        import os

        from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        cfg_path = os.path.join(checkpoint_dir, "config.json")
        with open(cfg_path, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        tokenizer = AutoTokenizer.from_pretrained(cfg["pretrained_model_name_or_path"])
        tokenizer.pad_token = tokenizer.eos_token

        # Build the LLaMA config, defaulting vocab_size to the tokenizer size
        # (also when config.json provides no llama_config at all).
        llama_cfg_dict = cfg.get("llama_config", {})
        if "vocab_size" not in llama_cfg_dict:
            llama_cfg_dict["vocab_size"] = len(tokenizer)
        llama_cfg = LlamaConfig(**llama_cfg_dict)

        llama_model = LlamaForCausalLM(llama_cfg).to(device).eval()

        hidden_size = llama_cfg.hidden_size
        embed_dim = cfg.get("embed_dim", 1024)

        t2vec_model = TextToEmbeddingModelPipeline(
            encoder="text_sonar_basic_encoder",
            tokenizer="text_sonar_basic_encoder",
            device=device,
        ).eval()

        vec2text_model = EmbeddingToTextModelPipeline(
            decoder="text_sonar_basic_decoder",
            tokenizer="text_sonar_basic_encoder",
            device=device,
        ).eval()

        forward_projector = Projector(embed_dim, hidden_size).to(device).eval()
        reverse_projector = Projector(hidden_size, embed_dim).to(device).eval()

        gen = cls(
            llama_model,
            forward_projector,
            reverse_projector,
            vec2text_model,
            t2vec_model,
            device,
            add_begin=cfg.get("add_begin", False),
        )

        # Load the trained weights (the file may wrap them in "model_state_dict").
        ckpt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
        state = torch.load(ckpt_bin, map_location=device, weights_only=True)
        state = state.get("model_state_dict", state)
        raw = gen.module if hasattr(gen, "module") else gen
        raw.load_state_dict(state, strict=False)

        return gen
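

if __name__ == "__main__":
    # Minimal usage sketch (the checkpoint path and the EOS convention are
    # hypothetical; adapt them to your layout). One plausible EOS embedding is
    # the SONAR encoding of a fixed sentinel sentence.
    generator = SONARLLMGenerator.load_from_checkpoint("checkpoints/sonar_llm")
    eos_emb = generator.t2vec.predict(["End of text."], source_lang="eng_Latn").to(generator.device)
    cfg = SONARLLMGenerationConfig(max_sentences=8)
    print(generator.generate("The expedition reached the valley at dawn.", eos_emb, cfg))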