Spaces:
Runtime error
Runtime error
| from utils.dataset_utils import * | |
class ImageDataset(Dataset):
    """Dataset that serves still images as fixed-length, repeated-frame clips.

    Every image file in ``image_dir`` whose name ends with one of
    ``img_types`` is loaded as RGB, resized to ``(height, width)`` (or to a
    bucketed size when ``use_bucketing`` is set), and repeated
    ``REPEAT_FRAMES`` times along a new leading axis. Each item also carries
    a text prompt and the prompt's token ids.
    """

    # Number of times a single still image is repeated along the frame axis.
    REPEAT_FRAMES = 16

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        base_width: int = 256,
        base_height: int = 256,
        use_caption: bool = False,
        image_dir: str = '',
        single_img_prompt: str = '',
        use_bucketing: bool = False,
        fallback_prompt: str = '',
        **kwargs
    ):
        """Collect the image paths and store the sampling configuration.

        Args:
            tokenizer: Passed to ``get_prompt_ids`` to tokenize prompts.
            width, height: Target resize size.
            base_width, base_height: Accepted but not used by this class.
            use_caption: Stored on the instance. NOTE(review): ``image_batch``
                currently always passes ``use_caption=True`` to
                ``get_text_prompt``, so this flag has no effect.
            image_dir: Directory to scan (non-recursively) for images.
            single_img_prompt: Passed to ``get_text_prompt`` as ``text_prompt``.
            use_bucketing: If True, pick the resize size via
                ``sensible_buckets`` from the image's own dimensions.
            fallback_prompt: Passed to ``get_text_prompt`` as the fallback.
            **kwargs: Accepted and ignored (lets callers share one config).
        """
        self.tokenizer = tokenizer
        # img_types must be set before get_images_list, which reads it.
        self.img_types = (".png", ".jpg", ".jpeg", '.bmp')
        self.use_bucketing = use_bucketing
        # Despite the name, this holds the sorted list of image file paths
        # (or the sentinel [''] when image_dir does not exist).
        self.image_dir = self.get_images_list(image_dir)
        self.fallback_prompt = fallback_prompt
        self.use_caption = use_caption
        self.single_img_prompt = single_img_prompt
        self.width = width
        self.height = height

    def get_images_list(self, image_dir):
        """Return sorted paths of files in ``image_dir`` with an image suffix.

        Returns ``['']`` when ``image_dir`` does not exist, and ``[]`` when it
        exists but contains no matching files.
        """
        if not os.path.exists(image_dir):
            return ['']
        return sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.endswith(self.img_types)
        )

    def image_batch(self, index):
        """Load image ``index`` and build ``(frames, prompt, prompt_ids)``.

        ``frames`` is the resized RGB tensor repeated ``REPEAT_FRAMES`` times
        along a new leading axis (``'c h w -> f c h w'``).
        """
        train_data = self.image_dir[index]
        try:
            img = torchvision.io.read_image(
                train_data, mode=torchvision.io.ImageReadMode.RGB
            )
        except Exception:
            # Fall back to PIL if torchvision cannot decode the file; the
            # context manager ensures the file handle is closed.
            with Image.open(train_data) as pil_img:
                img = T.transforms.PILToTensor()(pil_img.convert("RGB"))

        width, height = self.width, self.height
        if self.use_bucketing:
            _, h, w = img.shape
            width, height = sensible_buckets(width, height, w, h)

        img = T.transforms.Resize((height, width), antialias=True)(img)
        img = repeat(img, 'c h w -> f c h w', f=self.REPEAT_FRAMES)

        prompt = get_text_prompt(
            file_path=train_data,
            text_prompt=self.single_img_prompt,
            fallback_prompt=self.fallback_prompt,
            ext_types=self.img_types,
            # NOTE(review): hard-coded True ignores self.use_caption; kept to
            # preserve existing behavior.
            use_caption=True
        )
        prompt_ids = get_prompt_ids(prompt, self.tokenizer)
        return img, prompt, prompt_ids

    @staticmethod
    def __getname__():
        """Return the dataset type tag stored in each example."""
        return 'image'

    def __len__(self):
        """Return the number of images, or 0 if none were found."""
        # image_dir is [''] for a missing directory and [] for a directory
        # without matching images; both mean an empty dataset.
        if self.image_dir and os.path.exists(self.image_dir[0]):
            return len(self.image_dir)
        return 0

    def __getitem__(self, index):
        """Return the example dict for image ``index``."""
        img, prompt, prompt_ids = self.image_batch(index)
        example = {
            # Maps 0..255 pixel values into [-1, 1].
            "pixel_values": (img / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }
        return example