|
|
import logging |
|
|
import os |
|
|
import json |
|
|
import random |
|
|
from torch.utils.data import Dataset |
|
|
import time |
|
|
from dataset.utils import load_image_from_path |
|
|
|
|
|
# petrel_client is an optional dependency: remember whether the import
# succeeded so later code can check `has_client` before touching `Client`.
try:
    from petrel_client.client import Client
except ImportError:
    has_client = False
else:
    has_client = True

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ImageVideoBaseDataset(Dataset):
    """Base class that implements the image and video loading methods.

    ``__init__`` only creates placeholders (``None``) for ``data_root``,
    ``anno_list``, ``transform``, ``video_reader`` and ``num_tries``;
    ``__getitem__`` and ``__len__`` raise ``NotImplementedError``.  The
    loaders below additionally read ``self.num_frames`` and
    ``self.sample_type`` (plus ``self.keys_indexfile`` for paths containing
    "webvid"), none of which are assigned in this class.
    """

    media_type = "video"

    def __init__(self):
        assert self.media_type in ["image", "video", "only_video"]
        self.data_root = None
        # Indexed by `get_anno`; each entry is used as a dict with an
        # "image" key (see `get_anno`).
        self.anno_list = None
        self.transform = None
        self.video_reader = None
        self.num_tries = None

        # A petrel client is only created when the optional import succeeded;
        # otherwise loaders receive `client=None`.
        self.client = None
        if has_client:
            self.client = Client('~/petreloss.conf')

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def get_anno(self, index):
        """Obtain the annotation for one media (video or image).

        Args:
            index (int): The media index.

        Returns: dict.
            - "image": the filename, video also use "image".  When
              ``self.data_root`` is set, it is joined in front of the
              stored value.
            - "caption": The caption for this file.

        When ``self.data_root`` is set, a shallow copy of the stored entry is
        returned, so ``self.anno_list`` is never modified and asking for the
        same index repeatedly always yields the same path.  When
        ``self.data_root`` is ``None`` the stored entry itself is returned.
        """
        anno = self.anno_list[index]
        if self.data_root is None:
            return anno
        # Copy before rewriting the path: writing the joined path back into
        # the stored dict would prepend a relative data_root again on every
        # later call for this index.
        resolved = dict(anno)
        resolved["image"] = os.path.join(self.data_root, anno["image"])
        return resolved

    def load_and_transform_media_data(self, index, data_path):
        """Dispatch to the image or video loader according to ``media_type``.

        ``self.clip_transform`` is forwarded to the loader when the instance
        defines it; otherwise ``False`` is used, which is also the default
        of both loaders.

        Args:
            index (int): The media index.
            data_path (str): Path of the media file to load.

        Returns:
            Whatever the selected loader returns.
        """
        # `clip_transform` is not assigned in `__init__`, so fall back to the
        # loaders' own default instead of raising AttributeError.
        clip_transform = getattr(self, "clip_transform", False)
        if self.media_type == "image":
            return self.load_and_transform_media_data_image(
                index, data_path, clip_transform=clip_transform
            )
        return self.load_and_transform_media_data_video(
            index, data_path, clip_transform=clip_transform
        )

    def load_and_transform_media_data_image(self, index, data_path, clip_transform=False):
        """Load one image and optionally apply ``self.transform``.

        Args:
            index (int): The media index; returned unchanged.
            data_path (str): Path passed to ``load_image_from_path``.
            clip_transform (bool): When true, ``self.transform`` is skipped
                and the loaded image is returned as is.

        Returns:
            tuple: ``(image, index)``.
        """
        image = load_image_from_path(data_path, client=self.client)
        if not clip_transform:
            image = self.transform(image)
        return image, index

    def load_and_transform_media_data_video(self, index, data_path, return_fps=False, clip=None, clip_transform=False):
        """Load one video, retrying with random replacements on failure.

        Up to ``self.num_tries`` attempts are made.  Whenever reading raises,
        a warning is logged, a new index is drawn uniformly from
        ``range(len(self))`` and its annotation path is used for the next
        attempt, so the returned index may differ from the one passed in.

        Args:
            index (int): The media index.
            data_path (str): Path of the video to read.
            return_fps (bool): When true, also return the timestamps of the
                sampled frames.
            clip: Forwarded to ``self.video_reader`` as ``clip``.
            clip_transform (bool): When true, ``self.transform`` is skipped.

        Returns:
            tuple: ``(frames, index)``, or ``(frames, index, sec)`` when
            ``return_fps`` is true, where ``sec`` is a list of strings, each
            being a sampled frame index divided by fps, rounded to one
            decimal.

        Raises:
            RuntimeError: If every one of the ``self.num_tries`` attempts
                failed.
        """
        # Loop-invariant: -1 is passed when the instance has no
        # `max_num_frames` attribute.
        max_num_frames = getattr(self, "max_num_frames", -1)

        for _ in range(self.num_tries):
            try:
                if "webvid" in data_path:
                    # Paths containing "webvid" are read through an index
                    # file looked up by video id under a fixed HDFS directory.
                    hdfs_dir = "hdfs://harunava/home/byte_ailab_us_cvg/user/weimin.wang/videogen_data/webvid_data/10M_full_train"
                    video_name = os.path.basename(data_path)
                    video_id = os.path.splitext(video_name)[0]
                    ind_file = os.path.join(hdfs_dir, self.keys_indexfile[video_id])
                    frames, frame_indices, fps = self.video_reader(
                        ind_file, video_id, self.num_frames, self.sample_type,
                        max_num_frames=max_num_frames, client=self.client, clip=clip
                    )
                else:
                    frames, frame_indices, fps = self.video_reader(
                        data_path, self.num_frames, self.sample_type,
                        max_num_frames=max_num_frames, client=self.client, clip=clip
                    )
            except Exception as e:
                # Deliberately broad: any read failure is treated as a
                # corrupted sample and replaced by a randomly chosen one.
                logger.warning(
                    f"Caught exception {e} when loading video {data_path}, "
                    f"randomly sample a new video as replacement"
                )
                index = random.randint(0, len(self) - 1)
                ann = self.get_anno(index)
                data_path = ann["image"]
                continue

            if not clip_transform:
                frames = self.transform(frames)
            if return_fps:
                sec = [str(round(f / fps, 1)) for f in frame_indices]
                return frames, index, sec
            return frames, index

        # Only reached when every attempt raised (success returns above).
        raise RuntimeError(
            f"Failed to fetch video after {self.num_tries} tries. "
            f"This might indicate that you have many corrupted videos."
        )
|
|
|