# File: omnivinci/audio_encoder.py (Hugging Face repository snapshot, commit fd01e7c, "Initial commit").
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class AudioTower(nn.Module):
    """Wrapper module that runs a list of audio inputs through an audio encoder.

    This class only stores bookkeeping state (``is_loaded``,
    ``audio_tower_name``, ``cfg_only``); it never assigns ``self.audio_tower``
    itself. The encoder must be attached as ``self.audio_tower`` before
    ``forward``, ``dtype`` or ``device`` are used -- presumably by a subclass
    or loader elsewhere in the project (TODO confirm).
    """

    def __init__(self, audio_tower, args, delay_load=False):
        """Record the encoder identifier and start in the "not loaded" state.

        Args:
            audio_tower: Value stored verbatim as ``audio_tower_name``.
            args: Not used by this class; kept so the constructor signature
                stays compatible with callers.
            delay_load: Not used by this class; kept for the same reason.
        """
        super().__init__()
        self.is_loaded = False
        self.audio_tower_name = audio_tower
        # Fallback returned by ``config`` while ``is_loaded`` is False.
        self.cfg_only = None

    def forward(self, sounds):
        """Encode every item of ``sounds`` and concatenate the results.

        Args:
            sounds: A list whose items are either inputs accepted by
                ``self.audio_tower`` or objects exposing an
                ``input_features`` attribute, in which case that attribute is
                what gets encoded. An empty list is not special-cased.

        Returns:
            A ``(sound_features, audio_output_lengths)`` tuple.
            ``sound_features`` is every item's ``last_hidden_state`` (cast to
            the dtype of the item's input) concatenated along dim 1 and then
            passed through ``squeeze(0)``. ``audio_output_lengths`` lists each
            item's size along dim 1, in input order.

        Raises:
            NotImplementedError: If ``sounds`` is not a list.
        """
        if not isinstance(sounds, list):
            raise NotImplementedError("Not implemented for this encoder")

        sound_features = []
        audio_output_lengths = []
        for sound in sounds:
            if hasattr(sound, "input_features"):
                # Read the attribute that hasattr() just confirmed instead of
                # subscripting, so inputs that expose ``input_features`` only
                # as an attribute (no ``__getitem__``) are handled too.
                sound = sound.input_features
            sound_feature = self.audio_tower(sound).last_hidden_state
            # Return features in the same dtype the caller supplied.
            sound_feature = sound_feature.to(sound.dtype)
            sound_features.append(sound_feature)
            audio_output_lengths.append(sound_feature.shape[1])

        sound_features = torch.cat(sound_features, dim=1).squeeze(0)
        return sound_features, audio_output_lengths

    @property
    def dummy_feature(self):
        """All-zero tensor of shape ``(1, hidden_size)`` on this module's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype reported by ``self.audio_tower`` (which must already be attached)."""
        return self.audio_tower.dtype

    @property
    def config(self):
        """The encoder's config once ``is_loaded`` is true, otherwise ``cfg_only``."""
        if self.is_loaded:
            return self.audio_tower.config
        return self.cfg_only

    @property
    def device(self):
        """Device reported by ``self.audio_tower`` (which must already be attached)."""
        return self.audio_tower.device

    @property
    def hidden_size(self):
        """``hidden_size`` read from ``config``; fails if ``config`` is still ``None``."""
        return self.config.hidden_size