import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
import os

# Models that use mean pooling
POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}


# Mean pooling - take the attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


class LanguageModel(nn.Module):
    def __init__(self, model='distilbert-base-uncased'):
        super(LanguageModel, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model
        # Remove the CLIP vision tower (only the text encoder is needed)
        if "clip" in self.model_name:
            self.model.vision_model = None
        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False
        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def forward(self, text_batch):
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # Ensure no gradients are computed for this forward pass
            if "clip" in self.model_name:
                sentence_embedding = self.model.get_text_features(**inputs)
                return sentence_embedding
            outputs = self.model(**inputs)
            if any(model in self.model_name for model in POOL_MODELS):
                sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
                # Normalize embeddings
                sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
            else:
                # Use the [CLS] token embedding as the sentence representation
                sentence_embedding = outputs.last_hidden_state[:, 0, :]
        return sentence_embedding


class LMHead(nn.Module):
    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super(LMHead, self).__init__()
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        # self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # Project the sentence embedding, L2-normalize it, then predict class logits
        embd = self.fc1(x)
        embd = F.normalize(embd, p=2, dim=1)
        deg_pred = self.fc2(embd)
        return embd, deg_pred
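

# --- Usage sketch (illustrative; not part of the original file) ---
# Shows how LanguageModel and LMHead could be wired together. The model name
# and the example texts below are assumptions. Note that the default
# embedding_dim=384 of LMHead matches MiniLM-style encoders, while
# "distilbert-base-uncased" yields 768-dim CLS embeddings and would need
# embedding_dim=768.
if __name__ == "__main__":
    encoder = LanguageModel(model="sentence-transformers/all-MiniLM-L6-v2")
    head = LMHead(embedding_dim=384, hidden_dim=256, num_classes=4)
    texts = ["a heavily blurred photo", "a clean, sharp photo"]
    with torch.no_grad():
        sentence_embeddings = encoder(texts)        # shape (2, 384), L2-normalized
        embd, deg_pred = head(sentence_embeddings)  # shapes (2, 256) and (2, 4)
    print(embd.shape, deg_pred.shape)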