import torch
import torch.nn as nn
from transformers import CLIPImageProcessor

try:
    from imagebind.models import imagebind_model
    from imagebind.models.imagebind_model import ModalityType
    from imagebind.data import load_and_transform_audio_data
except ImportError:
    pass
class ImageBindWrapper(nn.Module):
    def __init__(self, vision_tower, select_layer, select_feature="patch", delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = select_layer
        self.select_feature = select_feature

        if not delay_load:
            self.load_model()

    def load_model(self):
        # The CLIP processor is reused only for image preprocessing; features come from ImageBind.
        self.image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.vision_tower = imagebind_model.imagebind_huge(pretrained=True)
        # Freeze the tower so it acts as a fixed feature extractor.
        for p in self.vision_tower.parameters():
            p.requires_grad = False
        self.vision_tower.eval()
        self.is_loaded = True

    def train(self, mode=True):
        # Keep the frozen tower in eval mode even when the wrapper is switched to train.
        self.training = mode
        if self.is_loaded:
            self.vision_tower.eval()

    def forward(self, x):
        if isinstance(x, dict):
            if x["audios"] is not None:
                inputs = {ModalityType.AUDIO: load_and_transform_audio_data(x["audios"], device=self.device).half()}
                embeddings = self.vision_tower(inputs)
                audio_embedding = embeddings[ModalityType.AUDIO]
                # (batch, dim) -> (batch, 1, dim): treat the audio embedding as a one-token sequence.
                return audio_embedding.unsqueeze(1)
        else:
            inputs = {ModalityType.VISION: x.to(dtype=self.dtype)}
            embeddings = self.vision_tower(inputs)
            vision_embedding = embeddings[ModalityType.VISION]
            if vision_embedding.ndim == 2:
                # Pooled embedding: add a sequence dimension.
                return vision_embedding.unsqueeze(1)
            if vision_embedding.shape[1] == 257:
                # Patch tokens: drop the leading CLS token, keep the 256 patches.
                return vision_embedding[:, 1:]
            raise ValueError(f"Unexpected shape: {vision_embedding.shape}")

    # dtype, device, dummy_feature, and hidden_size are accessed as attributes
    # (e.g. self.dtype and self.device in forward), so they must be properties,
    # not plain methods.
    @property
    def dummy_feature(self):
        return torch.zeros(1, 1024, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.dtype

    @property
    def device(self):
        return self.vision_tower.modality_preprocessors.vision.cls_token.device

    @property
    def hidden_size(self):
        return 1024