from typing import Union

import ast

import torch
from transformers import PreTrainedModel

from .model import WavJEPA
from .configuration_wavjepa import WavJEPAConfig
from .audio_extractor import ConvFeatureExtractor


class WavJEPAModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the WavJEPA audio model.

    Builds the underlying :class:`WavJEPA` network (a conv feature extractor
    feeding transformer encoder/decoder stacks) from a
    :class:`WavJEPAConfig`, so the model can be saved/loaded with the
    standard ``from_pretrained`` / ``save_pretrained`` machinery.
    """

    config_class = WavJEPAConfig

    def __init__(self, config):
        """Construct the wrapped WavJEPA model from ``config``.

        Args:
            config: A ``WavJEPAConfig`` carrying ``extractor_config`` (dict),
                encoder/decoder layer configs, and ``model_size``.
        """
        super().__init__(config)

        # ``conv_layers_spec`` is stored in the (JSON-serializable) config as a
        # string literal. Parse it with ast.literal_eval instead of eval():
        # eval() on a config value loaded from disk is arbitrary code
        # execution. literal_eval only accepts Python literals, which is all a
        # layer spec should be. If a caller already supplies a parsed
        # list/tuple, pass it through unchanged.
        spec = config.extractor_config['conv_layers_spec']
        if isinstance(spec, str):
            spec = ast.literal_eval(spec)

        self.model = WavJEPA(
            feature_extractor=ConvFeatureExtractor(
                conv_layers_spec=spec,
                in_channels=config.extractor_config['in_channels'],
                dropout=config.extractor_config['dropout'],
                mode=config.extractor_config['mode'],
                conv_bias=config.extractor_config['conv_bias'],
                depthwise=config.extractor_config['depthwise'],
            ),
            transformer_encoder_layers_cfg=config.encoder_layers_cfg,
            transformer_encoder_cfg=config.encoder_cfg,
            transformer_decoder_layers_cfg=config.decoder_layers_cfg,
            transformer_decoder_cfg=config.decoder_cfg,
            size=config.model_size,
        )

    def forward(self, tensor) -> torch.Tensor:
        """Return the audio representation for ``tensor``.

        Args:
            tensor: Input waveform batch (shape not established here —
                presumably (batch, samples); confirm against ``WavJEPA``).

        NOTE(review): the original annotation was
        ``Union[torch.Tensor, torch.Tensor]``, which collapses to
        ``torch.Tensor``. If ``get_audio_representation`` actually returns a
        pair, the intended annotation was ``Tuple[torch.Tensor, torch.Tensor]``
        — confirm upstream before tightening further.
        """
        return self.model.get_audio_representation(tensor)