from typing import Iterable, List, Optional, Tuple, Union

import torch
from fairseq2.data.data_pipeline import read_sequence
from fairseq2.generation import (
    BeamSearchSeq2SeqGenerator,
    Sampler,
    SamplingSeq2SeqGenerator,
    Seq2SeqGenerator,
    SequenceToTextConverter,
)

from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline as _BaseEmbeddingToTextModelPipeline,
)
from sonar.inference_pipelines.utils import add_progress_bar


class EmbeddingToTextModelPipeline(_BaseEmbeddingToTextModelPipeline):
    """Drop-in replacement that can also return sentence log-probabilities
    via ``return_scores``.

    - When ``return_scores=False`` (default), behaves exactly like the base
      pipeline and returns ``List[str]``.
    - When ``return_scores=True``, returns a tuple ``(List[str], List[float])``
      where each float is the best hypothesis score reported by fairseq2:
      the sum of token log-probabilities if ``normalize_scores=False``,
      otherwise length-normalized per fairseq2 semantics.
    """

    @torch.inference_mode()
    def predict(
        self,
        inputs: torch.Tensor,
        target_lang: str,
        batch_size: int = 5,
        progress_bar: bool = False,
        sampler: Optional[Sampler] = None,
        return_scores: bool = False,
        **generator_kwargs,
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        # Use sampling if a sampler is given; otherwise fall back to beam search.
        if sampler is not None:
            generator: Seq2SeqGenerator = SamplingSeq2SeqGenerator(
                self.model, sampler, **generator_kwargs
            )
        else:
            generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)

        converter = SequenceToTextConverter(
            generator,
            self.tokenizer,
            task="translation",
            target_lang=target_lang,
        )

        def _do_translate(src_tensors: List[torch.Tensor]):
            # batch_convert returns the decoded texts plus the raw generator
            # output, which carries the per-hypothesis scores.
            texts, gen_out = converter.batch_convert(
                torch.stack(src_tensors).to(self.device), None
            )
            if return_scores:
                scores: List[float] = []
                for hyps in gen_out.hypotheses:
                    # Guard against empty hypothesis lists or missing scores.
                    if len(hyps) == 0 or hyps[0].score is None:
                        scores.append(0.0)
                    else:
                        scores.append(float(hyps[0].score))
                return texts, scores
            return texts

        # Group the input embeddings into batches and translate each batch.
        pipeline: Iterable = (
            read_sequence(list(inputs))
            .bucket(batch_size)
            .map(_do_translate)
            .and_return()
        )
        if progress_bar:
            pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)

        results: List = list(iter(pipeline))

        if not return_scores:
            # results is List[List[str]] → flatten.
            return [text for batch_texts in results for text in batch_texts]

        # results is List[Tuple[List[str], List[float]]] → flatten both.
        all_texts: List[str] = []
        all_scores: List[float] = []
        for batch in results:
            batch_texts, batch_scores = batch
            all_texts.extend(batch_texts)
            all_scores.extend(batch_scores)
        return all_texts, all_scores
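

# Usage sketch (illustrative, not part of the pipeline): the model card names
# below follow the SONAR README ("text_sonar_basic_encoder" /
# "text_sonar_basic_decoder"); swap in whichever cards are installed in your
# environment. `beam_size` is forwarded to fairseq2's BeamSearchSeq2SeqGenerator
# through **generator_kwargs.
if __name__ == "__main__":
    from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

    text2vec = TextToEmbeddingModelPipeline(
        encoder="text_sonar_basic_encoder",
        tokenizer="text_sonar_basic_encoder",
    )
    vec2text = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_encoder",
    )

    embeddings = text2vec.predict(
        ["The weather is nice today.", "SONAR maps sentences into one space."],
        source_lang="eng_Latn",
    )

    # Default path: behaves like the base pipeline and returns List[str].
    texts = vec2text.predict(embeddings, target_lang="fra_Latn", beam_size=5)

    # Scored path: returns (List[str], List[float]); with beam search and
    # fairseq2's default normalize_scores=True, each score is the
    # length-normalized log-probability of the best hypothesis.
    texts, scores = vec2text.predict(
        embeddings, target_lang="fra_Latn", return_scores=True
    )
    for text, score in zip(texts, scores):
        print(f"{score:8.3f}  {text}")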