|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import PreTrainedModel |
|
|
from .configuration_antispoofing import DF_Arena_1B_Config |
|
|
from .backbone import DF_Arena_1B |
|
|
from .feature_extraction_antispoofing import AntispoofingFeatureExtractor |
|
|
|
|
|
class DF_Arena_1B_Antispoofing(PreTrainedModel):
    """`PreTrainedModel` wrapper around the `DF_Arena_1B` backbone.

    The wrapper owns two components: an `AntispoofingFeatureExtractor`
    (stored on the instance, not invoked by `forward`) and the
    `DF_Arena_1B` backbone that produces the logits.
    """

    # Configuration class associated with this model.
    config_class = DF_Arena_1B_Config

    def __init__(self, config: DF_Arena_1B_Config):
        """Build the feature extractor and backbone.

        Args:
            config: Model configuration; it is forwarded to
                `PreTrainedModel.__init__`. Both sub-components are
                constructed with their own default arguments.
        """
        super().__init__(config)

        # Kept as an attribute; `forward` does not call it.
        self.feature_extractor = AntispoofingFeatureExtractor()

        # The backbone is created without arguments, so `config` does not
        # influence its construction here.
        self.backbone = DF_Arena_1B()

        self.post_init()

    def forward(self, input_values, attention_mask=None):
        """Run the backbone on `input_values`.

        Args:
            input_values: Input passed unchanged to the backbone.
                NOTE(review): expected shape/dtype is defined by
                `DF_Arena_1B` — confirm against the backbone.
            attention_mask: Accepted but not used by this method.

        Returns:
            A dict with a single key, ``"logits"``, holding the backbone
            output.
        """
        return {"logits": self.backbone(input_values)}
|
|
|