from typing import Iterable, List, Optional

import torch

from fairseq2.data.data_pipeline import read_sequence
from fairseq2.generation import (
    BeamSearchSeq2SeqGenerator,
    Sampler,
    SamplingSeq2SeqGenerator,
    Seq2SeqGenerator,
    SequenceToTextConverter,
)

from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline as _BaseEmbeddingToTextModelPipeline,
)
from sonar.inference_pipelines.utils import add_progress_bar


class EmbeddingToTextModelPipeline(_BaseEmbeddingToTextModelPipeline):
    """Drop-in replacement that can also return sentence log-probabilities via ``return_scores``.

    - When ``return_scores=False`` (default), behaves exactly like the base pipeline
      and returns ``List[str]``.
    - When ``return_scores=True``, returns a tuple ``(List[str], List[float])`` where
      each float is the hypothesis score from fairseq2 (the sum of token
      log-probabilities if ``normalize_scores=False``, otherwise length-normalized
      per fairseq2 semantics).
    """

    @torch.inference_mode()
    def predict(
        self,
        inputs: torch.Tensor,
        target_lang: str,
        batch_size: int = 5,
        progress_bar: bool = False,
        sampler: Optional[Sampler] = None,
        return_scores: bool = False,
        **generator_kwargs,
    ):
        # Both strategies implement the Seq2SeqGenerator interface, so the
        # converter below is agnostic to how hypotheses are produced.
        if sampler is not None:
            generator: Seq2SeqGenerator = SamplingSeq2SeqGenerator(
                self.model, sampler, **generator_kwargs
            )
        else:
            generator = BeamSearchSeq2SeqGenerator(self.model, **generator_kwargs)

        converter = SequenceToTextConverter(
            generator,
            self.tokenizer,
            task="translation",
            target_lang=target_lang,
        )

        def _do_translate(src_tensors: List[torch.Tensor]):
            texts, gen_out = converter.batch_convert(
                torch.stack(src_tensors).to(self.device), None
            )
            if return_scores:
                # Keep the score of the best (first) hypothesis per sentence;
                # fall back to 0.0 when no scored hypothesis is available.
                scores: List[float] = []
                for hyps in gen_out.hypotheses:
                    if len(hyps) == 0 or hyps[0].score is None:
                        scores.append(0.0)
                    else:
                        scores.append(float(hyps[0].score))
                return texts, scores
            return texts

        # Batch the input embeddings and translate each batch lazily.
        pipeline: Iterable = (
            read_sequence(list(inputs))
            .bucket(batch_size)
            .map(_do_translate)
            .and_return()
        )

        if progress_bar:
            pipeline = add_progress_bar(pipeline, inputs=inputs, batch_size=batch_size)

        results: List = list(iter(pipeline))

        if not return_scores:
            # Each batch is a plain list of strings; flatten across batches.
            return [text for batch_texts in results for text in batch_texts]

        # Each batch is a (texts, scores) pair; flatten both in parallel.
        all_texts: List[str] = []
        all_scores: List[float] = []
        for batch in results:
            batch_texts, batch_scores = batch
            all_texts.extend(batch_texts)
            all_scores.extend(batch_scores)
        return all_texts, all_scores
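

# Example usage -- a minimal sketch, assuming the standard SONAR asset cards
# "text_sonar_basic_decoder" (decoder) and "text_sonar_basic_encoder"
# (tokenizer) are available to fairseq2's asset store, and that the decoder
# expects 1024-dimensional SONAR sentence embeddings. Adjust the card names
# and embedding size to your setup.
if __name__ == "__main__":
    pipeline = EmbeddingToTextModelPipeline(
        decoder="text_sonar_basic_decoder",
        tokenizer="text_sonar_basic_encoder",
    )
    embeddings = torch.randn(2, 1024)  # stand-in for real SONAR embeddings
    texts, scores = pipeline.predict(
        embeddings, target_lang="eng_Latn", return_scores=True
    )
    for text, score in zip(texts, scores):
        print(f"{score:.3f}\t{text}")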