import torch import torch.nn as nn from transformers import PreTrainedModel from .configuration_antispoofing import DF_Arena_1B_Config from .backbone import DF_Arena_1B from .feature_extraction_antispoofing import AntispoofingFeatureExtractor class DF_Arena_1B_Antispoofing(PreTrainedModel): config_class = DF_Arena_1B_Config def __init__(self, config: DF_Arena_1B_Config): super().__init__(config) self.feature_extractor = AntispoofingFeatureExtractor() # your backbone here (CNN/TDNN/Wav2Vec front-end, etc.) self.backbone = DF_Arena_1B() self.post_init() def forward(self, input_values, attention_mask=None): # input_values: (batch, time) float32 waveform @ config.sample_rate logits = self.backbone(input_values) return {"logits": logits}