import time
from dataclasses import dataclass
from typing import List, Optional, Tuple

import nltk
from nltk.tokenize import sent_tokenize

nltk.download("punkt", quiet=True)

import torch
import torch.nn as nn
import torch.nn.functional as F

from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline


class Projector(nn.Module):
    """Linear projection between the SONAR embedding space and the LLM hidden space."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


@dataclass
class SONARLLMGenerationConfig:
    # Outer sentence-level beam
    sentence_beam_size: int = 4
    latent_samples_per_step: int = 4  # M latent variants per active beam state
    # Token-level decoder params
    decoder_beam_size: int = 5  # default in fairseq2
    decoder_temperature: float = 1.0  # default in fairseq2
    normalize_sentence_scores: bool = True  # False → sum of token log-probs
    decoder_max_len: int = 256
    # Latent sampling
    temperature: float = 0.4
    temperature_mode: str = "absolute"  # "absolute" or "relative" (sigma scaled by hidden-state RMS)
    latent_top_p: Optional[float] = None  # 0 < p <= 1; None disables truncation
    # Repetition penalty in latent direction space (defaults below are assumptions)
    repetition_penalty: Optional[float] = None  # None or 1.0 disables the mean-shift penalty
    repetition_memory: int = 0  # number of recent latent directions to remember; 0 disables
    # Stopping criteria (defaults below are assumptions)
    max_sentences: int = 32  # hard cap on generated sentences per call
    eos_threshold: float = 0.9  # cosine similarity to eos_emb at which a beam counts as finished
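
# Illustrative sketch (not used by the generator below): how `temperature` and
# `latent_top_p` translate into the norm of the latent perturbation. The squared
# norm of a standard Gaussian in `dim` dimensions follows ChiSquare(dim); the
# sampler draws a quantile from the lower `top_p` mass of that distribution via
# the Wilson–Hilferty approximation and scales the resulting radius by sigma.
# The function name and the example dimensionality are assumptions for
# demonstration only.
def _example_perturbation_norm(dim: int = 2048, sigma: float = 0.4, top_p: float = 1.0) -> float:
    u = torch.rand(())                              # uniform quantile in (0, 1)
    p = torch.clamp(u * top_p, 1e-12, 1.0 - 1e-12)  # keep only the lower top_p mass
    k = torch.tensor(float(dim))
    z = torch.sqrt(torch.tensor(2.0)) * torch.special.erfinv(2.0 * p - 1.0)  # Phi^{-1}(p)
    chi2_q = k * torch.clamp(1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k)), min=1e-12) ** 3
    return float(torch.sqrt(chi2_q)) * sigma        # typically close to sqrt(dim) * sigma when top_p = 1
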
class SONARLLMGenerator(nn.Module):
    def __init__(
        self,
        llama_model: nn.Module,
        forward_proj: Projector,
        reverse_proj: Projector,
        sonar_decoder: EmbeddingToTextModelPipeline,
        t2vec_model: TextToEmbeddingModelPipeline,
        device: torch.device,
        add_begin: bool = False,
    ) -> None:
        super().__init__()
        self.llama_model = llama_model.eval()
        self.forward_proj = forward_proj.eval()
        self.reverse_proj = reverse_proj.eval()
        self.sonar_decoder = sonar_decoder.eval()
        self.t2vec = t2vec_model.eval()
        self.device = device
        self.add_begin = add_begin

    @torch.no_grad()
    def generate(
        self,
        prefix_text: str,
        eos_emb: torch.Tensor,
        cfg: Optional[SONARLLMGenerationConfig] = None,
    ) -> str:
        # Normalize and attach config to the instance for helper use
        if cfg is None:
            cfg = SONARLLMGenerationConfig()
        self._cfg = cfg

        sents = sent_tokenize(prefix_text)
        if self.add_begin:
            sents = ["Begin of text."] + sents
        if len(sents) == 0:
            sents = [prefix_text.strip()]

        # Initialize prefix embeddings
        emb_seq = self.t2vec.predict(sents, source_lang="eng_Latn").to(self.device)

        # Beam state tuple: (sentences, embeddings_seq, cumulative_score, recent_dirs)
        beams: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = [
            (sents[:], emb_seq, 0.0, [])
        ]

        steps = 0
        while steps < self._cfg.max_sentences:
            steps += 1
            candidates: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
            for (hist_sents, hist_emb, score, recent_dirs) in beams:
                candidates.extend(
                    self._expand_beam_state(hist_sents, hist_emb, score, recent_dirs, eos_emb)
                )
            # Keep top-k beams
            if len(candidates) == 0:
                break
            candidates.sort(key=lambda b: b[2], reverse=True)
            beams = candidates[: int(self._cfg.sentence_beam_size)]
            # If all beams look ended by EOS threshold, stop early
            if self._all_close_to_eos(beams, eos_emb):
                break

        best = max(beams, key=lambda b: b[2])
        result = self._join_sentences(best[0])
        if self.add_begin:
            result = result[len("Begin of text."):]
        return result

    # --- internals ---

    @torch.no_grad()
    def _forward_hidden(self, emb_seq: torch.Tensor) -> torch.Tensor:
        # Project SONAR embeddings into the LLM space and take the final hidden state
        proj = self.forward_proj(emb_seq.unsqueeze(0)) if emb_seq.ndim == 2 else self.forward_proj(emb_seq)
        out = self.llama_model(inputs_embeds=proj, output_hidden_states=True)
        hidden = out.hidden_states[-1]
        return hidden[0, -1, :]

    def _join_sentences(self, sents: List[str]) -> str:
        return " ".join(sents)

    def _update_recent_dirs(
        self, recent: List[torch.Tensor], u: torch.Tensor, memory_cap: int
    ) -> List[torch.Tensor]:
        if memory_cap <= 0:
            return recent
        if not torch.isfinite(u).all():
            return recent
        new_recent = recent + [u.detach().to("cpu")]
        if len(new_recent) > int(memory_cap):
            new_recent = new_recent[-int(memory_cap):]
        return new_recent

    def _sample_noise_direction(
        self, final_hidden: torch.Tensor, recent_dirs: List[torch.Tensor]
    ) -> torch.Tensor:
        g = torch.randn_like(final_hidden)
        if (
            self._cfg.repetition_penalty is not None
            and float(self._cfg.repetition_penalty) != 1.0
            and self._cfg.repetition_memory > 0
            and len(recent_dirs) > 0
        ):
            g = self._apply_repetition_penalty_to_direction(
                g, float(self._cfg.repetition_penalty), int(self._cfg.repetition_memory), recent_dirs
            )
        return g / (g.norm(p=2) + 1e-12)

    def _sample_noise(
        self, final_hidden: torch.Tensor, dir_unit: torch.Tensor
    ) -> torch.Tensor:
        t = float(self._cfg.temperature)
        if t <= 0.0:
            return torch.zeros_like(final_hidden)
        if self._cfg.temperature_mode not in ("absolute", "relative"):
            raise ValueError(f"Unsupported temperature_mode: {self._cfg.temperature_mode}")
        if self._cfg.temperature_mode == "absolute":
            sigma = torch.tensor(t, device=final_hidden.device, dtype=final_hidden.dtype)
        else:
            rms = torch.sqrt(torch.mean(final_hidden.to(torch.float32) ** 2))
            rms = torch.clamp(rms, min=1e-12).to(dtype=final_hidden.dtype, device=final_hidden.device)
            sigma = rms * t
        top_p = self._cfg.latent_top_p
        if top_p is None:
            top_p = 1.0
        return self._sample_truncated_normal_like(final_hidden, float(top_p), sigma, dir_unit)

    def _sample_truncated_normal_like(
        self, base_vector: torch.Tensor, top_p: float, sigma: torch.Tensor, dir_unit: torch.Tensor
    ) -> torch.Tensor:
        # Wilson–Hilferty approximation for ChiSquare quantiles: the squared norm of a
        # standard Gaussian in `dim` dimensions is ChiSquare(dim), so the sampled radius
        # is the square root of a quantile drawn from the lower `top_p` mass.
        dim = base_vector.numel()
        device = base_vector.device
        u = torch.rand((), device=device, dtype=torch.float32)
        p = torch.clamp(u * float(top_p), min=1e-12, max=1.0 - 1e-12)
        k = torch.tensor(float(dim), device=device, dtype=torch.float32)
        z = torch.sqrt(torch.tensor(2.0, device=device, dtype=torch.float32)) * torch.special.erfinv(2.0 * p - 1.0)
        term = 1.0 - 2.0 / (9.0 * k) + z * torch.sqrt(2.0 / (9.0 * k))
        term = torch.clamp(term, min=1e-12)
        s = k * (term ** 3)
        r = torch.sqrt(torch.clamp(s, min=1e-12)).to(dtype=base_vector.dtype)
        return dir_unit * (r * sigma)

    def _expand_beam_state(
        self,
        hist_sents: List[str],
        hist_emb: torch.Tensor,
        score: float,
        recent_dirs: List[torch.Tensor],
        eos_emb: torch.Tensor,
    ) -> List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]]:
        """Expand one beam state into candidate next states.

        Returns a list of (new_hist_sents, new_hist_emb, new_score, new_recent_dirs).
        """
        final_hidden = self._forward_hidden(hist_emb)
        out: List[Tuple[List[str], torch.Tensor, float, List[torch.Tensor]]] = []
        for _ in range(max(1, int(self._cfg.latent_samples_per_step))):
            # Perturb the predicted hidden state and decode it back into a sentence
            dir_unit = self._sample_noise_direction(final_hidden, recent_dirs)
            noise = self._sample_noise(final_hidden, dir_unit)
            h_perturbed = final_hidden + noise
            z = self.reverse_proj(h_perturbed.unsqueeze(0))
            texts, scores = self.sonar_decoder.predict(
                z,
                target_lang="eng_Latn",
                beam_size=int(self._cfg.decoder_beam_size),
                normalize_scores=bool(self._cfg.normalize_sentence_scores),
                max_seq_len=self._cfg.decoder_max_len,
                temperature=float(self._cfg.decoder_temperature),
                return_scores=True,
            )
            text = texts[0]
            sent_logprob = float(scores[0])
            # Re-encode the decoded sentence so the next step conditions on a clean embedding
            z_re = self.t2vec.predict([text], source_lang="eng_Latn").to(self.device)
            cand_score = score + sent_logprob
            new_recent = self._update_recent_dirs(recent_dirs, dir_unit, self._cfg.repetition_memory)
            new_hist_sents = hist_sents + [text]
            new_hist_emb = torch.cat([hist_emb, z_re], dim=0)
            out.append((new_hist_sents, new_hist_emb, cand_score, new_recent))
        return out

    def _apply_repetition_penalty_to_direction(
        self, g: torch.Tensor, penalty: float, memory_cap: int, recent_dirs: List[torch.Tensor]
    ) -> torch.Tensor:
        """Mean-shift (A+) repetition penalty in latent direction space.

        - penalty is clamped to [0, 1].
        - penalty = 0 → no shift (q = 0.5).
        - penalty = 1 → maximum shift (q ≈ q_min).

        Mapping: q = 0.5^(1-penalty) * q_min^(penalty), beta = Phi^{-1}(1 - q),
        and we set g' = g - beta * b_unit, where b_unit is the normalized
        average of recent directions.
""" if memory_cap <= 0 or len(recent_dirs) == 0: return g # Aggregate and normalize recent directions B = torch.stack( [u.to(device=g.device, dtype=g.dtype) for u in recent_dirs[-int(memory_cap):]], dim=0 ) b = B.mean(dim=0) bn = b.norm(p=2) if not torch.isfinite(bn) or bn <= 1e-12: return g b_unit = b / bn # Clamp and map penalty → beta via q rp = float(penalty) if rp < 0.0: rp = 0.0 if rp > 1.0: rp = 1.0 q_min = 1e-12 log_q = (1.0 - rp) * torch.log(torch.tensor(0.5, device=g.device, dtype=torch.float32)) log_q = log_q + rp * torch.log(torch.tensor(q_min, device=g.device, dtype=torch.float32)) q = torch.exp(log_q) p = torch.clamp(1.0 - q, 1e-12, 1.0 - 1e-12) beta = torch.sqrt(torch.tensor(2.0, device=g.device, dtype=g.dtype)) * torch.special.erfinv(2.0 * p - 1.0) beta = torch.clamp(beta, 0.0, 7.5) return g - (beta * b_unit) def _all_close_to_eos(self, beams, eos_emb: torch.Tensor) -> bool: for (_, emb, _, _) in beams: last = emb[-1:, :] sim = F.cosine_similarity(last, eos_emb, dim=1).item() if sim < float(self._cfg.eos_threshold): return False return True # --- factory --- @classmethod def load_from_checkpoint( cls, checkpoint_dir: str, device: Optional[torch.device] = None, generation_config: Optional[SONARLLMGenerationConfig] = None, ) -> "SONARLLMGenerator": """Load generator from a folder with config.json and weights. The folder is expected to contain: - config.json (with keys: pretrained_model_name_or_path, llama_config?, embed_dim) - pytorch_model.bin (or model_state_dict inside the saved file) """ import json import os from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM from .embedding_to_text_with_scores import EmbeddingToTextModelPipeline from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cfg_path = os.path.join(checkpoint_dir, "config.json") with open(cfg_path, "r", encoding="utf-8") as f: cfg = json.load(f) tokenizer = AutoTokenizer.from_pretrained(cfg["pretrained_model_name_or_path"]) tokenizer.pad_token = tokenizer.eos_token llama_cfg_dict = cfg.get("llama_config", {}) if "vocab_size" not in llama_cfg_dict: llama_cfg_dict["vocab_size"] = len(tokenizer) # llama_cfg_dict["pad_token_id"] = tokenizer.pad_token_id # llama_cfg_dict["bos_token_id"] = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 128000 # llama_cfg_dict["eos_token_id"] = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 128001 llama_cfg = LlamaConfig(**llama_cfg_dict) if "llama_config" in cfg else LlamaConfig() llama_model = LlamaForCausalLM(llama_cfg).to(device).eval() hidden_size = llama_cfg.hidden_size embed_dim = cfg.get("embed_dim", 1024) t2vec_model = TextToEmbeddingModelPipeline( encoder="text_sonar_basic_encoder", tokenizer="text_sonar_basic_encoder", device=device, ).eval() vec2text_model = EmbeddingToTextModelPipeline( decoder="text_sonar_basic_decoder", tokenizer="text_sonar_basic_encoder", device=device, ).eval() forward_projector = Projector(embed_dim, hidden_size).to(device).eval() reverse_projector = Projector(hidden_size, embed_dim).to(device).eval() gen = cls( llama_model, forward_projector, reverse_projector, vec2text_model, t2vec_model, device, add_begin=cfg.get("add_begin", False), ) # Load weights into generator to cover llama + projectors ckpt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin") state = torch.load(ckpt_bin, map_location=device, weights_only=True) state = state.get("model_state_dict", state) raw = gen.module if 
hasattr(gen, "module") else gen raw.load_state_dict(state, strict=False) return gen