| from transformers import Pipeline | |
| import torch | |
| from .feature_extraction_antispoofing import AntispoofingFeatureExtractor | |
class AntispoofingPipeline(Pipeline):
    """Audio anti-spoofing classification pipeline.

    Runs raw audio through ``AntispoofingFeatureExtractor``, feeds the
    resulting ``input_values`` to the model, and converts the model's
    ``logits`` into a predicted label, its softmax confidence, and a
    per-label score mapping.
    """

    def __init__(self, model, **kwargs):
        """Create the pipeline around ``model``.

        Args:
            model: Classification model; its ``config.id2label`` is used to
                name predictions in ``postprocess``.
            **kwargs: Forwarded unchanged to ``transformers.Pipeline``
                (e.g. ``feature_extractor``, ``device``).
        """
        super().__init__(model=model, **kwargs)
        # Only install the default extractor when the caller did not supply
        # one; unconditionally assigning here used to discard an explicit
        # ``feature_extractor=`` argument stored by ``Pipeline.__init__``.
        if getattr(self, "feature_extractor", None) is None:
            self.feature_extractor = AntispoofingFeatureExtractor()

    def _sanitize_parameters(self, **kwargs):
        """Route call-time kwargs to (preprocess, forward, postprocess).

        Only ``sampling_rate`` is recognised; it is sent to ``preprocess``.
        Forward and postprocess receive no extra parameters.
        """
        preprocess_kwargs = {}
        if "sampling_rate" in kwargs:
            preprocess_kwargs["sampling_rate"] = kwargs["sampling_rate"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, audio, sampling_rate=16000):
        """Convert raw ``audio`` into the model's input dict.

        Args:
            audio: Raw audio in whatever form ``AntispoofingFeatureExtractor``
                accepts (its contract is not visible here — confirm).
            sampling_rate: Accepted so callers can pass ``sampling_rate=...``
                through the pipeline, but currently NOT forwarded to the
                feature extractor.

        Returns:
            ``{"input_values": ...}``, unpacked into the model by ``_forward``.
        """
        # NOTE(review): ``sampling_rate`` is ignored. Confirm whether
        # AntispoofingFeatureExtractor accepts a ``sampling_rate`` keyword
        # before forwarding it; until then a non-default rate has no effect.
        features = self.feature_extractor(audio)
        return {"input_values": features["input_values"]}

    def _forward(self, model_inputs):
        """Run the model on the dict produced by ``preprocess``."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Turn the model's logits into a labelled prediction.

        Every derived field is computed from the first row of the batch; the
        pipeline presumably feeds one example per call (TODO confirm for
        batched use).

        Returns:
            dict with ``label`` (name of the top class), ``logits`` (raw
            logits as nested lists, batch dimension included), ``score``
            (softmax probability of the top class) and ``all_scores``
            (label name -> softmax probability).
        """
        logits = model_outputs["logits"]
        # Softmax over the class axis, then keep the first example so that
        # label, score and all_scores all describe the same row.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        id2label = self.model.config.id2label
        predicted_class = torch.argmax(probs).item()
        all_scores = {id2label[i]: p for i, p in enumerate(probs.tolist())}
        return {
            "label": id2label[predicted_class],
            "logits": logits.tolist(),
            "score": probs[predicted_class].item(),
            "all_scores": all_scores,
        }