|
|
""" |
|
|
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py |
|
|
""" |
|
|
import random |
|
|
import io |
|
|
import os |
|
|
import av |
|
|
import cv2 |
|
|
import decord |
|
|
import imageio |
|
|
from decord import VideoReader |
|
|
|
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import math |
|
|
|
|
|
import logging

logger = logging.getLogger(__name__)

# Make decord hand back torch tensors: read_frames_decord calls the torch
# method ``.permute`` directly on the result of ``VideoReader.get_batch``.
decord.bridge.set_bridge("torch")
|
|
|
|
|
def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
    """
    Convert a presentation timestamp to seconds.

    ``start_pts`` is subtracted from ``pts``, the resulting tick count is
    truncated with ``int()``, and then scaled by ``time_base``.

    Args:
        pts: timestamp in ``time_base`` units; ``math.inf`` is passed through.
        time_base: length of one tick in seconds (PyAV supplies a Fraction).
        start_pts: offset subtracted from ``pts`` before scaling.

    Returns:
        time_in_seconds (float): The corresponding time in seconds
        (``math.inf`` when ``pts`` is infinite).

    Adapted from
    https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
    """
    if pts != math.inf:
        elapsed_ticks = int(pts - start_pts)
        return elapsed_ticks * time_base
    # An unbounded timestamp stays unbounded.
    return math.inf
|
|
|
|
|
|
|
|
def get_pyav_video_duration(video_reader):
    """
    Return the duration, in seconds, of the first video stream of a PyAV container.

    The stream-level duration is preferred. Some containers do not report it
    (PyAV returns ``None``), in which case the container-level duration is used.

    Args:
        video_reader: an open PyAV input container (as returned by ``av.open``).

    Returns:
        float: duration in seconds.

    Raises:
        ValueError: if neither the stream nor the container reports a duration.
    """
    video_stream = video_reader.streams.video[0]
    if video_stream.duration is not None:
        # start_time may also be unknown (None); treat it as a zero offset
        # instead of failing on ``None`` arithmetic.
        start_pts = video_stream.start_time if video_stream.start_time is not None else 0
        video_duration = pts_to_secs(
            video_stream.duration,
            video_stream.time_base,
            start_pts
        )
    elif video_reader.duration is not None:
        # Container duration is expressed in av.time_base units (microseconds).
        video_duration = video_reader.duration / av.time_base
    else:
        raise ValueError(
            "Cannot determine video duration: neither the video stream nor "
            "the container reports one"
        )
    return float(video_duration)
|
|
|
|
|
|
|
|
def get_frame_indices_by_fps():
    """
    Placeholder for fps-based frame selection; not implemented.

    Takes no arguments and always returns ``None``. fps-based sampling is
    currently done by ``get_frame_indices`` with ``sample="fps<rate>"``.
    """
    return None
|
|
|
|
|
|
|
|
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """
    Choose which frame indices to read from a video with ``vlen`` frames.

    Args:
        num_frames: number of indices requested by the "rand" / "middle" strategies.
        vlen: total number of frames in the video.
        sample: sampling strategy.
            - "rand": split [0, vlen) into ``min(num_frames, vlen)`` contiguous
              segments and pick one uniformly random frame from each segment.
            - "middle": pick the middle frame of each segment or, when
              ``fix_start`` is given, the frame ``fix_start`` positions after
              each segment's first frame.
            - "fps<rate>" (e.g. "fps0.5"): pick frames at ``rate`` frames per
              second, assuming the video plays at ``input_fps``.
        fix_start: offset from each segment start (only used by "middle").
        input_fps: frame rate of the video (only used by "fps<rate>").
        max_num_frames: when > 0, cap on the indices returned by "fps<rate>".

    Returns:
        list of frame indices in increasing order. For "rand" / "middle" the list
        has ``num_frames`` entries; when ``vlen < num_frames`` it is padded by
        repeating the last index.

    Raises:
        ValueError: if ``sample`` is not a known strategy, or if padding is
            required but the video has no frames.
    """
    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, vlen)

        # Split [0, vlen) into acc_samples contiguous segments; each tuple is the
        # INCLUSIVE (first, last) frame index of one segment.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = [(first, nxt - 1) for first, nxt in zip(intervals[:-1], intervals[1:])]

        if sample == 'rand':
            # randint is inclusive on both ends, so every frame of a segment can be
            # drawn. range(first, last) would never pick `last`, making 2-frame
            # segments deterministic and failing on 1-frame segments.
            frame_indices = [random.randint(int(first), int(last)) for first, last in ranges]
        elif fix_start is not None:
            frame_indices = [first + fix_start for first, _ in ranges]
        else:  # sample == 'middle'
            frame_indices = [(first + last) // 2 for first, last in ranges]

        if len(frame_indices) < num_frames:
            if not frame_indices:
                raise ValueError(
                    f"Cannot sample {num_frames} frames from a video with {vlen} frames"
                )
            # Fewer frames than requested: repeat the last index to reach num_frames.
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps
        # Sample at the centre of each 1/output_fps window.
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unknown frame sampling strategy: {sample!r}")
    return frame_indices
|
|
|
|
|
|
|
|
def read_frames_av(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, clip=None,
):
    """
    Read sampled RGB frames from a video with PyAV.

    Every frame of the first video stream is decoded, then ``get_frame_indices``
    chooses which ones to keep.

    Args:
        video_path: local path, or an 's3'/'p2' key that is fetched through
            ``client`` when a client is given.
        num_frames, sample, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        client: object with a ``get(path) -> bytes`` method used for remote paths;
            when ``None`` the path is opened directly by PyAV.
        clip: accepted for interface compatibility; ignored by this reader.

    Returns:
        (frames, frame_indices, fps): ``frames`` stacks the selected frames and is
        permuted from (T, H, W, C) to (T, C, H, W); ``fps`` is the decoded frame
        count divided by the duration reported by the container.
    """
    if client is not None and video_path.startswith(('s3', 'p2')):
        # Remote object: fetch the bytes and let PyAV read from memory.
        source = io.BytesIO(client.get(video_path))
    else:
        source = video_path

    # The context manager closes the container (and its file handle) afterwards;
    # previously the container was never closed.
    with av.open(source) as reader:
        # Keep numpy arrays here; only the sampled frames become tensors below.
        frames = [f.to_rgb().to_ndarray() for f in reader.decode(video=0)]
        duration = get_pyav_video_duration(reader)

    vlen = len(frames)
    fps = vlen / float(duration)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = torch.stack([torch.from_numpy(frames[idx]) for idx in frame_indices])
    frames = frames.permute(0, 3, 1, 2)
    return frames, frame_indices, fps
|
|
|
|
|
|
|
|
def _gif_frame_to_chw_uint8(frame):
    """Convert one decoded GIF frame ((H, W), (H, W, 3) or (H, W, 4)) to a (3, H, W) uint8 tensor."""
    if frame.ndim == 2:
        # Grayscale frame: replicate the single channel.
        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    elif frame.shape[-1] == 4:
        # RGBA frame: drop the alpha channel.
        frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB)
    # ascontiguousarray also turns reader-specific ndarray subclasses into a plain ndarray.
    return torch.from_numpy(np.ascontiguousarray(frame)).byte().permute(2, 0, 1)


def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, clip=None,
):
    """
    Read sampled RGB frames from a GIF with imageio.

    Args:
        video_path: local path, or an 's3'/'p2' key fetched through ``client``.
        num_frames, sample, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        client: object with a ``get(path) -> bytes`` method (remote paths only).
        clip: accepted for interface compatibility; ignored by this reader.

    Returns:
        (frames, frame_indices, fps): ``frames`` holds one (C, H, W) uint8 frame per
        entry of ``frame_indices``, in that order (repeated indices yield repeated
        frames); ``fps`` is the hard-coded constant 25.0, not read from the file.
    """
    if video_path.startswith(('s3', 'p2')):
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)

    try:
        vlen = len(gif)
        frame_indices = get_frame_indices(
            num_frames, vlen, sample=sample, fix_start=fix_start,
            max_num_frames=max_num_frames
        )
        # Decode each requested frame once, keyed by index. A plain membership test
        # on the list would drop repeated (padded) indices and return too few frames.
        wanted = set(frame_indices)
        decoded = {}
        for index, frame in enumerate(gif):
            if index in wanted:
                decoded[index] = _gif_frame_to_chw_uint8(frame)
                if len(decoded) == len(wanted):
                    break  # every requested frame is decoded; skip the rest
    finally:
        gif.close()

    frames = torch.stack([decoded[idx] for idx in frame_indices])
    return frames, frame_indices, 25.
|
|
|
|
|
|
|
|
def read_frames_hdfs(ind_file, vid, num_frames, sample='rand', fix_start=None,
                     max_num_frames=-1, client=None, clip=None):
    """
    Read sampled frames of video ``vid`` from a key-value store of encoded frames.

    The record stored under ``vid`` is parsed with
    ``tf.io.parse_single_sequence_example``; its ``data`` sequence feature holds one
    encoded image per frame. Only the sampled frames are decoded (as 3-channel JPEG).

    NOTE(review): ``tf`` and ``KVReader`` are used here but are not imported in this
    file, so calling this function raises NameError unless they are provided
    elsewhere — TODO confirm where they are expected to come from.

    Args:
        ind_file: index file path; its extension is stripped to get the name passed
            to ``KVReader``.
        vid: key of the video record.
        num_frames, fix_start, max_num_frames: forwarded to ``get_frame_indices``.
        sample: currently ignored — sampling is always "rand" (see note below).
        client, clip: accepted for interface compatibility; unused.

    Returns:
        (frames, frame_indices, fps): ``frames`` has one decoded frame per entry of
        ``frame_indices`` (repeated indices yield repeated frames), permuted with
        ``permute(0, 3, 1, 2)``; ``fps`` is the hard-coded constant 25.
    """
    _context_features = {'title': tf.io.FixedLenFeature([], dtype=tf.string)}
    _sequence_features = {'data': tf.io.FixedLenSequenceFeature([], dtype=tf.string)}
    num_parallel_reader = 1
    filename, _extension = os.path.splitext(ind_file)
    reader = KVReader(filename, num_parallel_reader)
    key = vid
    values = reader.read_many([key])
    item = values[0]
    _contexts, sequences = tf.io.parse_single_sequence_example(
        serialized=item,
        context_features=_context_features,
        sequence_features=_sequence_features)

    rawframes = sequences['data']
    vlen = len(rawframes)
    # NOTE(review): this unconditionally overrides the caller's ``sample`` argument
    # (kept as in the original); confirm that "rand" is intended for every call.
    sample = "rand"

    frame_indices = get_frame_indices(num_frames, vlen, sample=sample,
                                      fix_start=fix_start,
                                      max_num_frames=max_num_frames)

    def read_image(raw_data):
        return tf.image.decode_jpeg(raw_data, channels=3, dct_method='INTEGER_ACCURATE').numpy()

    # Decode each requested frame once, keyed by index. A plain membership test on
    # the list would drop repeated (padded) indices and return too few frames.
    wanted = set(frame_indices)
    decoded = {}
    for index, frame in enumerate(rawframes):
        if index in wanted:
            decoded[index] = torch.as_tensor(read_image(frame))
            if len(decoded) == len(wanted):
                break  # every requested frame is decoded; skip the rest

    frames = torch.stack([decoded[idx] for idx in frame_indices])
    frames = frames.permute(0, 3, 1, 2)
    return frames, frame_indices, 25
|
|
|
|
|
|
|
|
def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None,
        max_num_frames=-1, client=None, clip=None
):
    """
    Read sampled frames from a video with decord.

    Args:
        video_path: local path, or an 's3'/'p2' key fetched through ``client``.
        num_frames, sample, fix_start, max_num_frames: forwarded to
            ``get_frame_indices``.
        client: object with a ``get(path) -> bytes`` method (remote paths only).
        clip: optional ``(start, end)`` in seconds; frames are sampled only from that
            window, clamped to the frames that actually exist in the video.

    Returns:
        (frames, frame_indices, fps): ``frames`` is the ``get_batch`` result permuted
        with ``permute(0, 3, 1, 2)`` (presumably (T, H, W, C) -> (T, C, H, W));
        ``frame_indices`` are absolute indices into the video; ``fps`` is decord's
        average frame rate as a float.
    """
    if video_path.startswith(('s3', 'p2')):
        video_bytes = client.get(video_path)
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
    else:
        video_reader = VideoReader(video_path, num_threads=1)
    total_frames = len(video_reader)
    fps = video_reader.get_avg_fps()

    vlen = total_frames
    start_index = 0
    if clip:
        start, end = clip
        start_index = int(start * fps)
        # A clip ending past the last frame would otherwise yield indices beyond
        # the video and make get_batch fail; only sample frames that exist.
        vlen = min(int((end - start) * fps), total_frames - start_index)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    if clip:
        # Shift clip-relative indices to absolute positions in the video.
        frame_indices = [f + start_index for f in frame_indices]

    frames = video_reader.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)
    return frames, frame_indices, float(fps)
|
|
|
|
|
|
|
|
# Dispatch table: backend name -> frame-reading function.
# NOTE: "hdfs" takes (ind_file, vid, num_frames, ...) while the other readers take
# (video_path, num_frames, ...), so callers must pass arguments accordingly.
VIDEO_READER_FUNCS = dict(
    av=read_frames_av,
    decord=read_frames_decord,
    gif=read_frames_gif,
    hdfs=read_frames_hdfs,
)
|
|
|