import json
import urllib.request

import torchvision.transforms as T
from cods.classif.data import ClassificationDataset
from PIL import Image

# Remote JSON mapping "<class index>" -> [wordnet_id, human_readable_name].
IMAGENET_CLASS_INDEX_URL = (
    "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
)


def _load_class_index(url, timeout):
    """Download and parse the class-index JSON at *url*.

    The response is opened in a ``with`` block so the underlying connection is
    always closed, and *timeout* (seconds, or ``None`` for the socket default)
    bounds how long the request may block.

    Returns:
        The decoded JSON object (expected to be a dict of
        ``str(index) -> [wnid, class_name]``).

    Raises:
        urllib.error.URLError: on network / HTTP failures.
        json.JSONDecodeError: if the payload is not valid JSON.
    """
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return json.load(response)


class DatasetWrapper(ClassificationDataset):
    """Wrap an indexable dataset of ``{"image": ..., "label": ...}`` records.

    Each item is returned as ``(idx, transformed_image, label)``. Images are
    converted to RGB when needed before the transform pipeline runs.

    Attributes:
        dataset: The wrapped dataset; must support ``len()`` and integer
            indexing returning a mapping with ``"image"`` and ``"label"`` keys.
        root: Fixed data root, ``"./data"``.
        image_ids: Initialised empty; not populated here.
        transforms: Callable applied to each image (see ``__init__``).
        wdnids: ``{class_index: wordnet_id}`` from the class-index JSON.
        idx_to_cls: ``{class_index: class_name}`` from the class-index JSON.
    """

    def __init__(
        self,
        dataset,
        transforms=None,
        class_index_url=IMAGENET_CLASS_INDEX_URL,
        timeout=30.0,
        **kwargs,
    ):
        """Build the wrapper and fetch the class-index mapping.

        Args:
            dataset: Indexable dataset yielding ``{"image", "label"}`` mappings.
            transforms: Image transform callable. ``None`` selects the default
                pipeline: resize to 256, center-crop 224, convert to tensor,
                and normalise with the standard ImageNet mean/std. A falsy
                non-``None`` value (e.g. ``False``) disables transforms in
                ``__getitem__``.
            class_index_url: URL of the class-index JSON. Defaults to the
                ImageNet index that was previously hard-coded.
            timeout: Seconds to wait for the download (``None`` = socket
                default).
            **kwargs: Accepted for interface compatibility but currently
                unused. NOTE(review): ``ClassificationDataset.__init__`` is not
                called here. Confirm that the base class needs no
                initialisation before forwarding these.
        """
        self.dataset = dataset
        self.root = "./data"
        self.image_ids = []

        if transforms is None:
            transforms = T.Compose(
                [
                    T.Resize(256),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    T.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225],
                    ),
                ],
            )
        self.transforms = transforms

        class_index = _load_class_index(class_index_url, timeout)
        # JSON object keys are strings; re-key by integer class index.
        self.wdnids = {int(k): v[0] for k, v in class_index.items()}
        self.idx_to_cls = {int(k): v[1] for k, v in class_index.items()}

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(idx, image, label)`` for position *idx*.

        The image is converted to RGB if its ``mode`` differs (e.g. grayscale
        or RGBA inputs). It is then passed through ``self.transforms`` when
        that is truthy.
        """
        item = self.dataset[idx]
        img, label = item["image"], item["label"]
        # Normalise channel layout so the transforms always see 3-channel input.
        if img.mode != "RGB":
            img = img.convert("RGB")
        if self.transforms:
            img = self.transforms(img)
        return idx, img, label