import ffmpegio
import gc
import torch
from transformers import MobileViTImageProcessor, MobileViTForSemanticSegmentation

from config import FPS_DIV, MAX_LENGTH, BATCH_SIZE, MODEL_PATH


class PreprocessModel(torch.nn.Module):
    device = 'cpu'

    def __init__(self):
        super().__init__()
        self.feature_extractor = MobileViTImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        self.mobile_vit = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
        self.convs = torch.nn.Sequential(
            torch.nn.MaxPool2d(2, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.mobile_vit(x).logits
        x = self.convs(x)
        return x
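    # Note: for "apple/deeplabv3-mobilevit-small" the segmentation head emits
    # 21-class (PASCAL VOC) logit maps; the MaxPool2d above halves their
    # spatial size, and the downstream code assumes the result is 16x16 per frame.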

    def read_video(self, path: str) -> torch.Tensor:
        """
        Reads a video and returns a tensor of per-frame features
        """
        # Read the first second of the clip, then subsample and cap the frame count
        _, video = ffmpegio.video.read(path, t=1.0)
        video = video[::FPS_DIV][:MAX_LENGTH]

        out_seg_video = []
        for i in range(0, video.shape[0], BATCH_SIZE):
            frames = [video[j] for j in range(i, min(i + BATCH_SIZE, video.shape[0]))]
            frames = self.feature_extractor(images=frames, return_tensors='pt')['pixel_values']
            out = self.forward(frames.to(self.device)).detach().to('cpu')
            out_seg_video.append(out)

            # Free per-batch buffers to keep peak memory low
            del frames, out
            gc.collect()
            if self.device == 'cuda':
                torch.cuda.empty_cache()

        return torch.cat(out_seg_video)
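
# A minimal sketch of using PreprocessModel on its own (the file name below is
# hypothetical, in the style of the commented example at the end of this module):
#
#   preprocess = PreprocessModel()
#   with torch.no_grad():
#       feats = preprocess.read_video('sample.mp4')
#   feats.shape  # expected: (n_frames, 21, 16, 16)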


class VideoModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        p = 0.5
        self.pic_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(21, 128, (2, 2), stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(128, 256, (2, 2), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(256, 256, (4, 4), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.Flatten()
        )
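
        # pic_cnn shape flow, assuming (21, 16, 16) inputs from PreprocessModel:
        # (21,16,16) -> (128,8,8) -> (256,4,4) -> (256,1,1) -> flatten 256,
        # matching the LSTM hidden size used for h0 in forward.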
        self.vid_cnn = torch.nn.Sequential(
            torch.nn.Conv2d(21, 128, (2, 2), stride=2),
            torch.nn.BatchNorm2d(128),
            torch.nn.Tanh(),
            torch.nn.Conv2d(128, 256, (2, 2), stride=2),
            torch.nn.BatchNorm2d(256),
            torch.nn.Dropout2d(p),
            torch.nn.LeakyReLU(),
            torch.nn.Conv2d(256, 512, (2, 2), stride=2),
            torch.nn.BatchNorm2d(512),
            torch.nn.Dropout2d(p),
            torch.nn.Flatten()
        )
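
        # vid_cnn shape flow for the same (21, 16, 16) frames:
        # (21,16,16) -> (128,8,8) -> (256,4,4) -> (512,2,2) -> flatten 2048,
        # matching the LSTM input size below.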
        self.lstm = torch.nn.LSTM(2048, 256, 1, batch_first=True, bidirectional=True)
        self.fc1 = torch.nn.Linear(256 * 2, 1024)
        self.fc_norm = torch.nn.BatchNorm1d(256 * 2)
        self.tanh = torch.nn.Tanh()
        self.fc2 = torch.nn.Linear(1024, 2)
        self.sigmoid = torch.nn.Sigmoid()
        self.dropout = torch.nn.Dropout(p)

        # Xavier init
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv3d)):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        """
        Uses the preview as the initial hidden state and the video frames as the sequence.
        video[0] is the preview, video[1:] are the video frames.

        :param video: torch.Tensor, shape = (frames + 1, 21, 16, 16)
        """
        frames = video.shape[0]
        # Left-pad the time dimension so every clip spans MAX_LENGTH + 1 frames
        video = torch.nn.functional.pad(video, (0, 0, 0, 0, 0, 0, MAX_LENGTH + 1 - frames, 0))
        video = video.unsqueeze(0)

        _batch_size = video.shape[0]
        _preview = video[:, 0, :, :]
        _video = video[:, 1:, :, :]

        h0 = self.pic_cnn(_preview).unsqueeze(0)
        # Add a zero initial state for the backward direction of the bidirectional LSTM
        h0 = torch.nn.functional.pad(h0, (0, 0, 0, 0, 0, 1))
        c0 = torch.zeros_like(h0)

        _video = self.vid_cnn(_video.reshape(-1, 21, 16, 16))
        _video = _video.reshape(_batch_size, MAX_LENGTH, -1)

        context, _ = self.lstm(_video, (h0, c0))

        out = self.fc_norm(context[:, -1])
        out = self.tanh(self.fc1(out))
        out = self.dropout(out)
        out = self.sigmoid(self.fc2(out))
        return out
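
# A minimal smoke test for VideoModel with random inputs (assumes MAX_LENGTH == 90,
# consistent with the reshape in forward); not part of the Space itself:
#
#   model = VideoModel()
#   model.eval()
#   dummy = torch.randn(91, 21, 16, 16)  # preview frame + 90 video frames
#   with torch.no_grad():
#       probs = model(dummy)  # shape (1, 2), values in (0, 1)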


# @st.cache_resource
class TikTokAnalytics(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.preprocessing_model = PreprocessModel()
        # MODEL_PATH is expected to hold a fully pickled VideoModel, so torch.load
        # returns the module itself rather than a state_dict
        self.predict_model = torch.load(MODEL_PATH, map_location=self.preprocessing_model.device)

        self.preprocessing_model.eval()
        self.predict_model.eval()

    def forward(self, path: str) -> torch.Tensor:
        """
        Run preprocessing, then prediction

        :param path: path to the video file
        :return: prediction tensor of shape (1, 2)
        """
        tensor = self.preprocessing_model.read_video(path)
        predict = self.predict_model(tensor)
        return predict


# if __name__ == '__main__':
#     model = TikTokAnalytics()
#     model = model(
#         '/Users/victorbarbarich/PycharmProjects/nueramic/vktrbr-video-tiktok/data/videos/video-6930454291186502917.mp4')
#     print(model)
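
# A hedged sketch of how the commented @st.cache_resource decorator above might
# be wired into the Streamlit app for this Space (load_model and the uploaded
# file name are hypothetical, not part of this repo):
#
#   import streamlit as st
#
#   @st.cache_resource
#   def load_model() -> TikTokAnalytics:
#       return TikTokAnalytics()
#
#   analytics = load_model()
#   probs = analytics('uploaded_video.mp4')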