# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import functools
import gzip
import hashlib
import json
import logging
import os
import random
import warnings
from collections import defaultdict
from itertools import islice
from pathlib import Path
from typing import (
    Any,
    ClassVar,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TYPE_CHECKING,
    Union,
)

import numpy as np
import torch
from PIL import Image
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
from tqdm import tqdm

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from typing import TypedDict

    class FrameAnnotsEntry(TypedDict):
        subset: Optional[str]
        frame_annotation: types.FrameAnnotation

else:
    FrameAnnotsEntry = dict

class JsonIndexDataset(DatasetBase, ReplaceableBase):
    """
    A dataset with annotations in json files like the Common Objects in 3D
    (CO3D) dataset.

    Args:
        frame_annotations_file: A zipped json file containing metadata of the
            frames in the dataset, serialized List[types.FrameAnnotation].
        sequence_annotations_file: A zipped json file containing metadata of the
            sequences in the dataset, serialized List[types.SequenceAnnotation].
        subset_lists_file: A json file containing the lists of frames
            corresponding to different subsets (e.g. train/val/test) of the
            dataset; format: {subset: (sequence_name, frame_id, file_path)}.
        subsets: Restrict frames/sequences only to the given list of subsets
            as defined in subset_lists_file (see above).
        limit_to: Limit the dataset to the first #limit_to frames (after other
            filters have been applied).
        limit_sequences_to: Limit the dataset to the first
            #limit_sequences_to sequences (after other sequence filters have been
            applied but before frame-based filters).
        pick_sequence: A list of sequence names to restrict the dataset to.
        exclude_sequence: A list of the names of the sequences to exclude.
        limit_category_to: Restrict the dataset to the given list of categories.
        dataset_root: The root folder of the dataset; all the paths in jsons are
            specified relative to this root (but not json paths themselves).
        load_images: Enable loading the frame RGB data.
        load_depths: Enable loading the frame depth maps.
        load_depth_masks: Enable loading the frame depth map masks denoting the
            depth values used for evaluation (the points consistent across views).
        load_masks: Enable loading frame foreground masks.
        load_point_clouds: Enable loading sequence-level point clouds.
        max_points: Cap on the number of loaded points in the point cloud;
            if reached, they are randomly sampled without replacement.
        mask_images: Whether to mask the images with the loaded foreground masks;
            0 value is used for background.
        mask_depths: Whether to mask the depth maps with the loaded foreground
            masks; 0 value is used for background.
        image_height: The height of the returned images, masks, and depth maps;
            aspect ratio is preserved during cropping/resizing.
        image_width: The width of the returned images, masks, and depth maps;
            aspect ratio is preserved during cropping/resizing.
        box_crop: Enable cropping of the image around the bounding box inferred
            from the foreground region of the loaded segmentation mask; masks
            and depth maps are cropped accordingly; cameras are corrected.
        box_crop_mask_thr: The threshold used to separate pixels into foreground
            and background based on the foreground_probability mask; if no value
            is greater than this threshold, the loader lowers it and repeats.
        box_crop_context: The amount of additional padding added to each
            dimension of the cropping bounding box, relative to box size.
        remove_empty_masks: Removes the frames with no active foreground pixels
            in the segmentation mask after thresholding (see box_crop_mask_thr).
        n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
            frames in each sequence uniformly without replacement if it has
            more frames than that; applied before other frame-level filters.
        seed: The seed of the random generator sampling #n_frames_per_sequence
            random frames per sequence.
        sort_frames: Enable frame annotations sorting to group frames from the
            same sequences together and order them by timestamps.
        eval_batches: A list of batches that form the evaluation set;
            list of batch-sized lists of indices corresponding to __getitem__
            of this class, thus it can be used directly as a batch sampler.
        eval_batch_index:
            ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
            A list of batches of frames described as (sequence_name, frame_idx)
            that can form the evaluation set; `eval_batches` will be set from this.
    """

    frame_annotations_type: ClassVar[
        Type[types.FrameAnnotation]
    ] = types.FrameAnnotation

    path_manager: Any = None
    frame_annotations_file: str = ""
    sequence_annotations_file: str = ""
    subset_lists_file: str = ""
    subsets: Optional[List[str]] = None
    limit_to: int = 0
    limit_sequences_to: int = 0
    pick_sequence: Tuple[str, ...] = ()
    exclude_sequence: Tuple[str, ...] = ()
    limit_category_to: Tuple[int, ...] = ()
    dataset_root: str = ""
    load_images: bool = True
    load_depths: bool = True
    load_depth_masks: bool = True
    load_masks: bool = True
    load_point_clouds: bool = False
    max_points: int = 0
    mask_images: bool = False
    mask_depths: bool = False
    image_height: Optional[int] = 800
    image_width: Optional[int] = 800
    box_crop: bool = True
    box_crop_mask_thr: float = 0.4
    box_crop_context: float = 0.3
    remove_empty_masks: bool = True
    n_frames_per_sequence: int = -1
    seed: int = 0
    sort_frames: bool = False
    eval_batches: Any = None
    eval_batch_index: Any = None
    # frame_annots: List[FrameAnnotsEntry] = field(init=False)
    # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)

    def __post_init__(self) -> None:
        # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
        self.subset_to_image_path = None
        self._load_frames()
        self._load_sequences()
        if self.sort_frames:
            self._sort_frames()
        self._load_subset_lists()
        self._filter_db()  # also computes sequence indices
        self._extract_and_set_eval_batches()
        logger.info(str(self))

    def _extract_and_set_eval_batches(self):
        """
        Sets eval_batches based on input eval_batch_index.
        """
        if self.eval_batch_index is not None:
            if self.eval_batches is not None:
                raise ValueError(
                    "Cannot define both eval_batch_index and eval_batches."
                )
            self.eval_batches = self.seq_frame_index_to_dataset_index(
                self.eval_batch_index
            )

    def join(self, other_datasets: Iterable[DatasetBase]) -> None:
        """
        Join the dataset with other JsonIndexDataset objects.

        Args:
            other_datasets: A list of JsonIndexDataset objects to be joined
                into the current dataset.
        """
        if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
            raise ValueError("This function can only join a list of JsonIndexDataset")
        # pyre-ignore[16]
        self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
        # pyre-ignore[16]
        self.seq_annots.update(
            # https://gist.github.com/treyhunner/f35292e676efa0be1728
            functools.reduce(
                lambda a, b: {**a, **b},
                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
            )
        )
        all_eval_batches = [
            self.eval_batches,
            # pyre-ignore
            *[d.eval_batches for d in other_datasets],
        ]
        if not (
            all(ba is None for ba in all_eval_batches)
            or all(ba is not None for ba in all_eval_batches)
        ):
            raise ValueError(
                "When joining datasets, either all joined datasets have to have their"
                " eval_batches defined, or all should have their eval batches undefined."
            )
        if self.eval_batches is not None:
            self.eval_batches = sum(all_eval_batches, [])
        self._invalidate_indexes(filter_seq_annots=True)

    def is_filtered(self) -> bool:
        """
        Returns `True` in case the dataset has been filtered and thus some frame
        annotations stored on the disk might be missing in the dataset object.

        Returns:
            is_filtered: `True` if the dataset has been filtered, else `False`.
        """
        return (
            self.remove_empty_masks
            or self.limit_to > 0
            or self.limit_sequences_to > 0
            or len(self.pick_sequence) > 0
            or len(self.exclude_sequence) > 0
            or len(self.limit_category_to) > 0
            or self.n_frames_per_sequence > 0
        )

    def seq_frame_index_to_dataset_index(
        self,
        seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
        allow_missing_indices: bool = False,
        remove_missing_indices: bool = False,
        suppress_missing_index_warning: bool = True,
    ) -> List[List[Union[Optional[int], int]]]:
        """
        Obtain indices into the dataset object given a list of frame ids.

        Args:
            seq_frame_index: The list of frame ids specified as
                `List[List[Tuple[sequence_name: str, frame_number: int]]]`.
                Optionally, image paths relative to the dataset_root can be
                specified as well:
                `List[List[Tuple[sequence_name: str, frame_number: int, image_path: str]]]`.
            allow_missing_indices: If `False`, throws an IndexError upon reaching the
                first entry from `seq_frame_index` which is missing in the dataset.
                Otherwise, depending on `remove_missing_indices`, either returns `None`
                in place of missing entries or removes the indices of missing entries.
            remove_missing_indices: Active when `allow_missing_indices=True`.
                If `False`, returns `None` in place of `seq_frame_index` entries that
                are not present in the dataset.
                If `True`, removes missing indices from the returned indices.
            suppress_missing_index_warning:
                Active if `allow_missing_indices==True`. Suppresses a warning message
                in case an entry from `seq_frame_index` is missing in the dataset
                (expected in certain cases - e.g. when setting
                `self.remove_empty_masks=True`).

        Returns:
            dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
        """
        _dataset_seq_frame_n_index = {
            seq: {
                # pyre-ignore[16]
                self.frame_annots[idx]["frame_annotation"].frame_number: idx
                for idx in seq_idx
            }
            # pyre-ignore[16]
            for seq, seq_idx in self._seq_to_idx.items()
        }

        def _get_dataset_idx(
            seq_name: str, frame_no: int, path: Optional[str] = None
        ) -> Optional[int]:
            idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
            idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
            if idx is None:
                msg = (
                    f"sequence_name={seq_name} / frame_number={frame_no}"
                    " not in the dataset!"
                )
                if not allow_missing_indices:
                    raise IndexError(msg)
                if not suppress_missing_index_warning:
                    warnings.warn(msg)
                return idx
            if path is not None:
                # Check that the loaded frame path is consistent
                # with the one stored in self.frame_annots.
                assert os.path.normpath(
                    # pyre-ignore[16]
                    self.frame_annots[idx]["frame_annotation"].image.path
                ) == os.path.normpath(
                    path
                ), f"Inconsistent frame indices {seq_name, frame_no, path}."
            return idx

        dataset_idx = [
            [_get_dataset_idx(*b) for b in batch]  # pyre-ignore [6]
            for batch in seq_frame_index
        ]

        if allow_missing_indices and remove_missing_indices:
            # remove all None indices, and also batches with only None entries
            valid_dataset_idx = [
                [b for b in batch if b is not None] for batch in dataset_idx
            ]
            return [  # pyre-ignore[7]
                batch for batch in valid_dataset_idx if len(batch) > 0
            ]

        return dataset_idx
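
    # For illustration (hypothetical names and indices): with
    # allow_missing_indices=True and remove_missing_indices=False, an input
    # [[("seq_a", 0), ("seq_a", 7)]] maps to [[3, None]] if frame 0 of "seq_a"
    # sits at dataset index 3 and frame 7 was filtered out; with
    # remove_missing_indices=True the same input yields [[3]].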

    def subset_from_frame_index(
        self,
        frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
        allow_missing_indices: bool = True,
    ) -> "JsonIndexDataset":
        """
        Generate a dataset subset given the list of frames specified in `frame_index`.

        Args:
            frame_index: The list of frame identifiers (as stored in the metadata)
                specified as `List[Tuple[sequence_name: str, frame_number: int]]`.
                Optionally, image paths relative to the dataset_root can be
                specified as well:
                `List[Tuple[sequence_name: str, frame_number: int, image_path: str]]`;
                in the latter case, if an image_path does not match the stored path,
                an error is raised.
            allow_missing_indices: If `False`, throws an IndexError upon reaching the
                first entry from `frame_index` which is missing in the dataset.
                Otherwise, generates a subset consisting of the frame entries that
                actually exist in the dataset.
        """
        # Get the indices into the frame annots.
        dataset_indices = self.seq_frame_index_to_dataset_index(
            [frame_index],
            allow_missing_indices=self.is_filtered() and allow_missing_indices,
        )[0]
        valid_dataset_indices = [i for i in dataset_indices if i is not None]

        # Deep copy the whole dataset except frame_annots, which are large so we
        # deep copy only the requested subset of frame_annots.
        memo = {id(self.frame_annots): None}  # pyre-ignore[16]
        dataset_new = copy.deepcopy(self, memo)
        dataset_new.frame_annots = copy.deepcopy(
            [self.frame_annots[i] for i in valid_dataset_indices]
        )

        # This will kill all unneeded sequence annotations.
        dataset_new._invalidate_indexes(filter_seq_annots=True)

        # Finally annotate the frame annotations with the name of the subset
        # stored in meta.
        for frame_annot in dataset_new.frame_annots:
            frame_annotation = frame_annot["frame_annotation"]
            if frame_annotation.meta is not None:
                frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)

        # A sanity check - this will crash in case some entries from frame_index
        # are missing in dataset_new.
        valid_frame_index = [
            fi for fi, di in zip(frame_index, dataset_indices) if di is not None
        ]
        dataset_new.seq_frame_index_to_dataset_index(
            [valid_frame_index], allow_missing_indices=False
        )
        return dataset_new

    def __str__(self) -> str:
        # pyre-ignore[16]
        return f"JsonIndexDataset #frames={len(self.frame_annots)}"

    def __len__(self) -> int:
        # pyre-ignore[16]
        return len(self.frame_annots)

    def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
        return entry["subset"]

    def get_all_train_cameras(self) -> CamerasBase:
        """
        Returns the cameras corresponding to all the known frames.
        """
        logger.info("Loading all train cameras.")
        cameras = []
        # pyre-ignore[16]
        for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
            frame_type = self._get_frame_type(frame_annot)
            if frame_type is None:
                raise ValueError("subsets not loaded")
            if is_known_frame_scalar(frame_type):
                cameras.append(self[frame_idx].camera)
        return join_cameras_as_batch(cameras)

    def __getitem__(self, index) -> FrameData:
        # pyre-ignore[16]
        if index >= len(self.frame_annots):
            raise IndexError(f"index {index} out of range {len(self.frame_annots)}")

        entry = self.frame_annots[index]["frame_annotation"]
        # pyre-ignore[16]
        point_cloud = self.seq_annots[entry.sequence_name].point_cloud
        frame_data = FrameData(
            frame_number=_safe_as_tensor(entry.frame_number, torch.long),
            frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
            sequence_name=entry.sequence_name,
            sequence_category=self.seq_annots[entry.sequence_name].category,
            camera_quality_score=_safe_as_tensor(
                self.seq_annots[entry.sequence_name].viewpoint_quality_score,
                torch.float,
            ),
            point_cloud_quality_score=_safe_as_tensor(
                point_cloud.quality_score, torch.float
            )
            if point_cloud is not None
            else None,
        )

        # The rest of the fields are optional
        frame_data.frame_type = self._get_frame_type(self.frame_annots[index])

        (
            frame_data.fg_probability,
            frame_data.mask_path,
            frame_data.bbox_xywh,
            clamp_bbox_xyxy,
            frame_data.crop_bbox_xywh,
        ) = self._load_crop_fg_probability(entry)

        scale = 1.0
        if self.load_images and entry.image is not None:
            # original image size
            frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)

            (
                frame_data.image_rgb,
                frame_data.image_path,
                frame_data.mask_crop,
                scale,
            ) = self._load_crop_images(
                entry, frame_data.fg_probability, clamp_bbox_xyxy
            )

        if self.load_depths and entry.depth is not None:
            (
                frame_data.depth_map,
                frame_data.depth_path,
                frame_data.depth_mask,
            ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)

        if entry.viewpoint is not None:
            frame_data.camera = self._get_pytorch3d_camera(
                entry,
                scale,
                clamp_bbox_xyxy,
            )

        if self.load_point_clouds and point_cloud is not None:
            pcl_path = self._fix_point_cloud_path(point_cloud.path)
            frame_data.sequence_point_cloud = _load_pointcloud(
                self._local_path(pcl_path), max_points=self.max_points
            )
            frame_data.sequence_point_cloud_path = pcl_path

        return frame_data

    def _fix_point_cloud_path(self, path: str) -> str:
        """
        Fix up a point cloud path from the dataset.
        Some files in Co3Dv2 have an accidental absolute path stored.
        """
        unwanted_prefix = (
            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
        )
        if path.startswith(unwanted_prefix):
            path = path[len(unwanted_prefix) :]
        return os.path.join(self.dataset_root, path)

    def _load_crop_fg_probability(
        self, entry: types.FrameAnnotation
    ) -> Tuple[
        Optional[torch.Tensor],
        Optional[str],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
    ]:
        fg_probability = None
        full_path = None
        bbox_xywh = None
        clamp_bbox_xyxy = None
        crop_box_xywh = None

        if (self.load_masks or self.box_crop) and entry.mask is not None:
            full_path = os.path.join(self.dataset_root, entry.mask.path)
            mask = _load_mask(self._local_path(full_path))

            if mask.shape[-2:] != entry.image.size:
                raise ValueError(
                    f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
                )

            bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))

            if self.box_crop:
                clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
                    _get_clamp_bbox(
                        bbox_xywh,
                        image_path=entry.image.path,
                        box_crop_context=self.box_crop_context,
                    ),
                    image_size_hw=tuple(mask.shape[-2:]),
                )
                crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)

                mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)

            fg_probability, _, _ = self._resize_image(mask, mode="nearest")

        return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh

    def _load_crop_images(
        self,
        entry: types.FrameAnnotation,
        fg_probability: Optional[torch.Tensor],
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
        assert self.dataset_root is not None and entry.image is not None
        path = os.path.join(self.dataset_root, entry.image.path)
        image_rgb = _load_image(self._local_path(path))

        if image_rgb.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
            )

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)

        image_rgb, scale, mask_crop = self._resize_image(image_rgb)

        if self.mask_images:
            assert fg_probability is not None
            image_rgb *= fg_probability

        return image_rgb, path, mask_crop, scale

    def _load_mask_depth(
        self,
        entry: types.FrameAnnotation,
        clamp_bbox_xyxy: Optional[torch.Tensor],
        fg_probability: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
        entry_depth = entry.depth
        assert entry_depth is not None
        path = os.path.join(self.dataset_root, entry_depth.path)
        depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            depth_bbox_xyxy = _rescale_bbox(
                clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
            )
            depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)

        depth_map, _, _ = self._resize_image(depth_map, mode="nearest")

        if self.mask_depths:
            assert fg_probability is not None
            depth_map *= fg_probability

        if self.load_depth_masks:
            assert entry_depth.mask_path is not None
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
            depth_mask = _load_depth_mask(self._local_path(mask_path))

            if self.box_crop:
                assert clamp_bbox_xyxy is not None
                depth_mask_bbox_xyxy = _rescale_bbox(
                    clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
                )
                depth_mask = _crop_around_box(
                    depth_mask, depth_mask_bbox_xyxy, mask_path
                )

            depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
        else:
            depth_mask = torch.ones_like(depth_map)

        return depth_map, path, depth_mask

    def _get_pytorch3d_camera(
        self,
        entry: types.FrameAnnotation,
        scale: float,
        clamp_bbox_xyxy: Optional[torch.Tensor],
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
        # principal point and focal length
        principal_point = torch.tensor(
            entry_viewpoint.principal_point, dtype=torch.float
        )
        focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)

        half_image_size_wh_orig = (
            torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
        )

        # first, we convert from the dataset's NDC convention to pixels
        format = entry_viewpoint.intrinsics_format
        if format.lower() == "ndc_norm_image_bounds":
            # this is e.g. currently used in CO3D for storing intrinsics
            rescale = half_image_size_wh_orig
        elif format.lower() == "ndc_isotropic":
            rescale = half_image_size_wh_orig.min()
        else:
            raise ValueError(f"Unknown intrinsics format: {format}")

        # principal point and focal length in pixels
        principal_point_px = half_image_size_wh_orig - principal_point * rescale
        focal_length_px = focal_length * rescale
        if self.box_crop:
            assert clamp_bbox_xyxy is not None
            principal_point_px -= clamp_bbox_xyxy[:2]

        # now, convert from pixels to PyTorch3D v0.5+ NDC convention
        if self.image_height is None or self.image_width is None:
            out_size = list(reversed(entry.image.size))
        else:
            out_size = [self.image_width, self.image_height]

        half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
        half_min_image_size_output = half_image_size_output.min()

        # rescaled principal point and focal length in ndc
        principal_point = (
            half_image_size_output - principal_point_px * scale
        ) / half_min_image_size_output
        focal_length = focal_length_px * scale / half_min_image_size_output

        return PerspectiveCameras(
            focal_length=focal_length[None],
            principal_point=principal_point[None],
            R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )
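
    # Worked example of the conversion above (assumed numbers): for a 200 x 100
    # (W x H) image with intrinsics stored in "ndc_norm_image_bounds" format,
    # rescale = (100, 50), so a principal point of (0.0, 0.0) in NDC maps to
    # the pixel (100, 50), i.e. the image center; the focal length is scaled by
    # the same per-axis factors before being re-normalized to the output NDC.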

    def _load_frames(self) -> None:
        logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
        local_file = self._local_path(self.frame_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            frame_annots_list = types.load_dataclass(
                zipfile, List[self.frame_annotations_type]
            )
        if not frame_annots_list:
            raise ValueError("Empty dataset!")
        # pyre-ignore[16]
        self.frame_annots = [
            FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
        ]

    def _load_sequences(self) -> None:
        logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
        local_file = self._local_path(self.sequence_annotations_file)
        with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
            seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
        if not seq_annots:
            raise ValueError("Empty sequences file!")
        # pyre-ignore[16]
        self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}

    def _load_subset_lists(self) -> None:
        logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
        if not self.subset_lists_file:
            return

        with open(self._local_path(self.subset_lists_file), "r") as f:
            subset_to_seq_frame = json.load(f)

        frame_path_to_subset = {
            path: subset
            for subset, frames in subset_to_seq_frame.items()
            for _, _, path in frames
        }
        # pyre-ignore[16]
        for frame in self.frame_annots:
            frame["subset"] = frame_path_to_subset.get(
                frame["frame_annotation"].image.path, None
            )
            if frame["subset"] is None:
                warnings.warn(
                    "Subset lists are given but don't include "
                    + frame["frame_annotation"].image.path
                )

    def _sort_frames(self) -> None:
        # Sort frames to have them grouped by sequence, ordered by timestamp
        # pyre-ignore[16]
        self.frame_annots = sorted(
            self.frame_annots,
            key=lambda f: (
                f["frame_annotation"].sequence_name,
                f["frame_annotation"].frame_timestamp or 0,
            ),
        )

    def _filter_db(self) -> None:
        if self.remove_empty_masks:
            logger.info("Removing images with empty masks.")
            # pyre-ignore[16]
            old_len = len(self.frame_annots)

            msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."

            def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
                mask = frame_annot.mask
                if mask is None:
                    return False
                if mask.mass is None:
                    raise ValueError(msg)
                return mask.mass > 1

            self.frame_annots = [
                frame
                for frame in self.frame_annots
                if positive_mass(frame["frame_annotation"])
            ]
            logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))

        # this has to be called after joining with categories!!
        subsets = self.subsets
        if subsets:
            if not self.subset_lists_file:
                raise ValueError(
                    "Subset filter is on but subset_lists_file was not given"
                )

            logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")

            # truncate the list of subsets to the valid one
            self.frame_annots = [
                entry for entry in self.frame_annots if entry["subset"] in subsets
            ]
            if len(self.frame_annots) == 0:
                raise ValueError(f"There are no frames in the '{subsets}' subsets!")

            self._invalidate_indexes(filter_seq_annots=True)

        if len(self.limit_category_to) > 0:
            logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
            # pyre-ignore[16]
            self.seq_annots = {
                name: entry
                for name, entry in self.seq_annots.items()
                if entry.category in self.limit_category_to
            }

        # sequence filters
        for prefix in ("pick", "exclude"):
            orig_len = len(self.seq_annots)
            attr = f"{prefix}_sequence"
            arr = getattr(self, attr)
            if len(arr) > 0:
                logger.info(f"{attr}: {str(arr)}")
                self.seq_annots = {
                    name: entry
                    for name, entry in self.seq_annots.items()
                    if (name in arr) == (prefix == "pick")
                }
                logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))

        if self.limit_sequences_to > 0:
            self.seq_annots = dict(
                islice(self.seq_annots.items(), self.limit_sequences_to)
            )

        # retain only frames from retained sequences
        self.frame_annots = [
            f
            for f in self.frame_annots
            if f["frame_annotation"].sequence_name in self.seq_annots
        ]

        self._invalidate_indexes()

        if self.n_frames_per_sequence > 0:
            logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
            keep_idx = []
            # pyre-ignore[16]
            for seq, seq_indices in self._seq_to_idx.items():
                # infer the seed from the sequence name, this is reproducible
                # and makes the selection differ for different sequences
                seed = _seq_name_to_seed(seq) + self.seed
                seq_idx_shuffled = random.Random(seed).sample(
                    sorted(seq_indices), len(seq_indices)
                )
                keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])

            logger.info(
                "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
            )
            self.frame_annots = [self.frame_annots[i] for i in keep_idx]
            self._invalidate_indexes(filter_seq_annots=False)
            # sequences are not decimated, so self.seq_annots is valid

        if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
            logger.info(
                "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
            )
            self.frame_annots = self.frame_annots[: self.limit_to]
            self._invalidate_indexes(filter_seq_annots=True)

    def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
        # update _seq_to_idx and filter seq_meta according to frame_annots change
        # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
        self._invalidate_seq_to_idx()
        if filter_seq_annots:
            # pyre-ignore[16]
            self.seq_annots = {
                k: v
                for k, v in self.seq_annots.items()
                # pyre-ignore[16]
                if k in self._seq_to_idx
            }

    def _invalidate_seq_to_idx(self) -> None:
        seq_to_idx = defaultdict(list)
        # pyre-ignore[16]
        for idx, entry in enumerate(self.frame_annots):
            seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
        # pyre-ignore[16]
        self._seq_to_idx = seq_to_idx

    def _resize_image(
        self, image, mode="bilinear"
    ) -> Tuple[torch.Tensor, float, torch.Tensor]:
        # takes numpy array, returns pytorch tensor
        image_height, image_width = self.image_height, self.image_width
        if image_height is None or image_width is None:
            # skip the resizing
            imre_ = torch.from_numpy(image)
            return imre_, 1.0, torch.ones_like(imre_[:1])
        minscale = min(
            image_height / image.shape[-2],
            image_width / image.shape[-1],
        )
        imre = torch.nn.functional.interpolate(
            torch.from_numpy(image)[None],
            scale_factor=minscale,
            mode=mode,
            align_corners=False if mode == "bilinear" else None,
            recompute_scale_factor=True,
        )[0]
        # pyre-fixme[19]: Expected 1 positional argument.
        imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
        imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
        # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
        # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
        mask = torch.zeros(1, self.image_height, self.image_width)
        mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
        return imre_, minscale, mask
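
    # Example of the resizing scheme above (assumed shapes): a 3 x 480 x 640
    # image with image_height = image_width = 800 gives
    # minscale = min(800 / 480, 800 / 640) = 1.25, i.e. a 600 x 800 resized
    # image zero-padded to 800 x 800; the returned mask marks the valid
    # 600 x 800 region.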

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path
        return self.path_manager.get_local_path(path)

    def get_frame_numbers_and_timestamps(
        self, idxs: Sequence[int]
    ) -> List[Tuple[int, float]]:
        out: List[Tuple[int, float]] = []
        for idx in idxs:
            # pyre-ignore[16]
            frame_annotation = self.frame_annots[idx]["frame_annotation"]
            out.append(
                (frame_annotation.frame_number, frame_annotation.frame_timestamp)
            )
        return out

    def category_to_sequence_names(self) -> Dict[str, List[str]]:
        c2seq = defaultdict(list)
        # pyre-ignore
        for sequence_name, sa in self.seq_annots.items():
            c2seq[sa.category].append(sequence_name)
        return dict(c2seq)

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        return self.eval_batches


def _seq_name_to_seed(seq_name) -> int:
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)


def _load_image(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        im = np.array(pil_im.convert("RGB"))
    im = im.transpose((2, 0, 1))
    im = im.astype(np.float32) / 255.0
    return im


def _load_16big_png_depth(depth_png) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth
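
# Round-trip sketch of the encoding above (the writer side is an assumption
# consistent with the comment in _load_16big_png_depth): a float32 depth value
# is cast to float16 and its raw bytes stored in a 16-bit PNG; np.frombuffer
# reverses that reinterpretation, e.g. the uint16 value 0x3E00 decodes to the
# float16 (and hence float32) value 1.5.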


def _load_1bit_png_mask(file: str) -> np.ndarray:
    with Image.open(file) as pil_im:
        mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
    return mask


def _load_depth_mask(path: str) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth mask file name "%s"' % path)
    m = _load_1bit_png_mask(path)
    return m[None]  # fake feature channel


def _load_depth(path, scale_adjustment) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)
    d = _load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0
    return d[None]  # fake feature channel


def _load_mask(path) -> np.ndarray:
    with Image.open(path) as pil_im:
        mask = np.array(pil_im)
    mask = mask.astype(np.float32) / 255.0
    return mask[None]  # fake feature channel


def _get_1d_bounds(arr) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1
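
# E.g. _get_1d_bounds([0, 0, 1, 1, 0]) == (2, 4): the bounds of the nonzero
# region as a half-open [start, stop) interval.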


def _get_bbox_from_mask(
    mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
    if thr <= 0.0:
        warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")

    x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def _get_clamp_bbox(
    bbox: torch.Tensor,
    box_crop_context: float = 0.0,
    image_path: str = "",
) -> torch.Tensor:
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    bbox = bbox.clone()  # do not edit bbox in place

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.float()
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        raise ValueError(
            f"squashed image {image_path}!! The bounding box contains no pixels."
        )

    bbox[2:] = torch.clamp(bbox[2:], 2)  # set min height, width to 2 along both axes
    bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)

    return bbox_xyxy


def _crop_around_box(tensor, bbox, impath: str = ""):
    # bbox is xyxy, where the upper bound is corrected with +1
    bbox = _clamp_box_to_image_bounds_and_round(
        bbox,
        image_size_hw=tensor.shape[-2:],
    )
    tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
    assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
    return tensor


def _clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy  # pyre-ignore [7]


def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
    assert bbox is not None
    assert np.prod(orig_res) > 1e-8
    # average ratio of dimensions
    rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
    return bbox * rel_size
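
# E.g. rescaling a bbox from a 200 x 100 image to a 100 x 50 depth map gives
# rel_size = (0.5 + 0.5) / 2 = 0.5, so all four bbox coordinates are halved.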


def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
    return xywh


def _bbox_xywh_to_xyxy(
    xywh: torch.Tensor, clamp_size: Optional[int] = None
) -> torch.Tensor:
    xyxy = xywh.clone()
    if clamp_size is not None:
        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
    xyxy[2:] += xyxy[:2]
    return xyxy
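
# E.g. _bbox_xywh_to_xyxy(torch.tensor([10, 20, 30, 40])) gives
# (10, 20, 40, 60); with clamp_size=2, a degenerate box (10, 20, 0, 1) is
# first clamped to width/height (2, 2) and becomes (10, 20, 12, 22).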


def _safe_as_tensor(data, dtype):
    if data is None:
        return None
    return torch.tensor(data, dtype=dtype)


# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
    pcl = IO().load_pointcloud(pcl_path)
    if max_points > 0:
        pcl = pcl.subsample(max_points)
    return pcl