from transformers import Pipeline
import torch

from .feature_extraction_antispoofing import AntispoofingFeatureExtractor


class AntispoofingPipeline(Pipeline):
    """Custom pipeline that runs an antispoofing model on audio input."""

    def __init__(self, model, **kwargs):
        super().__init__(model=model, **kwargs)
        self.feature_extractor = AntispoofingFeatureExtractor()

    def _sanitize_parameters(self, **kwargs):
        # Split call-time kwargs into (preprocess, forward, postprocess) kwargs.
        preprocess_kwargs = {}
        postprocess_kwargs = {}
        if "sampling_rate" in kwargs:
            preprocess_kwargs["sampling_rate"] = kwargs["sampling_rate"]
        return preprocess_kwargs, {}, postprocess_kwargs

    def preprocess(self, audio, sampling_rate=16000):
        # NOTE: sampling_rate is accepted here but is not forwarded to the
        # feature extractor, so it currently has no effect.
        input_values = self.feature_extractor(audio)["input_values"]
        return {"input_values": input_values}

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        logits = model_outputs["logits"]
        probs = torch.nn.functional.softmax(logits, dim=-1)
        # .item() requires a single prediction, so this assumes batch size 1.
        predicted_class = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][predicted_class].item()
        id2label = self.model.config.id2label
        return {
            "label": id2label[predicted_class],
            "logits": logits.tolist(),
            "score": confidence,
            # Probability for every class, keyed by its label.
            "all_scores": {
                id2label[i]: probs[0][i].item() for i in range(len(probs[0]))
            },
        }