# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class AudioTower(nn.Module):
    """Wrapper module that runs an audio encoder over a list of sound inputs.

    This class never assigns ``self.audio_tower`` itself; ``forward`` and the
    ``dtype`` / ``device`` properties read it, so it has to be attached by a
    subclass or by the caller before they are used.
    """

    def __init__(self, audio_tower, args, delay_load=False):
        """Record the encoder identifier and mark the tower as not loaded.

        Args:
            audio_tower: Stored as ``audio_tower_name``.
            args: Accepted but not used in this class.
            delay_load: Accepted but not used in this class.
        """
        super().__init__()

        self.is_loaded = False
        self.audio_tower_name = audio_tower
        # Returned by ``config`` while ``is_loaded`` is False.
        self.cfg_only = None

    @staticmethod
    def _unwrap_input_features(sound):
        """Return the ``input_features`` payload of *sound*, or *sound* itself.

        Handles both attribute-style containers and plain dicts, so the way
        the payload is read always matches the way its presence was detected.
        """
        if hasattr(sound, "input_features"):
            return sound.input_features
        if isinstance(sound, dict) and "input_features" in sound:
            return sound["input_features"]
        return sound

    def forward(self, sounds):
        """Encode every entry of *sounds* and concatenate the results.

        Args:
            sounds: A list whose entries are either encoder inputs or
                containers exposing them as ``input_features`` (attribute or
                dict key).

        Returns:
            A tuple ``(sound_features, audio_output_lengths)``:
            ``sound_features`` is the encoder's ``last_hidden_state`` for each
            entry, concatenated along dim 1 and then squeezed on dim 0;
            ``audio_output_lengths`` holds each entry's size along dim 1.

        Raises:
            NotImplementedError: If *sounds* is not a list.
        """
        if not isinstance(sounds, list):
            raise NotImplementedError("Not implemented for this encoder")

        sound_features = []
        audio_output_lengths = []
        for sound in sounds:
            sound = self._unwrap_input_features(sound)
            sound_feature = self.audio_tower(sound).last_hidden_state
            # Cast the encoder output back to the dtype of its input.
            sound_feature = sound_feature.to(sound.dtype)
            sound_features.append(sound_feature)
            audio_output_lengths.append(sound_feature.shape[1])

        # NOTE(review): presumably each feature is (1, seq, hidden), making the
        # result (total_seq, hidden); squeeze(0) is a no-op when dim 0 != 1.
        sound_features = torch.cat(sound_features, dim=1).squeeze(0)
        return sound_features, audio_output_lengths

    @property
    def dummy_feature(self):
        """Return a zero tensor of shape ``(1, hidden_size)`` on this tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Return the dtype reported by the wrapped ``audio_tower``."""
        return self.audio_tower.dtype

    @property
    def config(self):
        """Return the encoder's config when loaded, otherwise ``cfg_only``."""
        if self.is_loaded:
            return self.audio_tower.config
        else:
            return self.cfg_only

    @property
    def device(self):
        """Return the device reported by the wrapped ``audio_tower``."""
        return self.audio_tower.device

    @property
    def hidden_size(self):
        """Return ``hidden_size`` from the active config."""
        return self.config.hidden_size