import collections
import glob
import json
import os
import random

import pandas as pd  # needed at runtime by pyarrow's Table.to_pandas()
import pyarrow.parquet as pq
import torch
from PIL import Image
from torch.utils.data import IterableDataset, DataLoader
from torchvision import transforms
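
# Streaming (IterableDataset) loaders for plain text, chat/SFT text,
# reasoning-to-image pairs, and VQA conversations. Each dataset shards its
# inputs across distributed ranks, shuffles samples through a bounded
# in-memory buffer, and exposes a collate_fn that stacks tensor fields while
# leaving raw-string fields (e.g. "input_ids") as Python lists.
#
# Note: none of these datasets split work across DataLoader workers, so use
# num_workers=0 (as in the __main__ test below) or add worker-aware sharding
# to avoid duplicated samples.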


class RefinedWebDataset(IterableDataset):
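    """Stream plain-text samples from sharded Parquet files.

    Reads the "content" column of each Parquet shard, crops overly long
    documents to a random ``max_length``-character window, and shuffles
    samples through a bounded buffer. Files are split across ranks via
    ``files[rank::world_size]``. Yields dicts whose "input_ids" field is a
    raw string; tokenization is left to downstream code.
    """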
    def __init__(self,
                 data_path,
                 rank: int = 0,
                 world_size: int = 1,
                 shuffle=True,
                 repeat=True,
                 buffer_size=1000,
                 max_length=8000,
                 num_workers=1):
        super().__init__()
        self.files = sorted(glob.glob(data_path))
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.repeat = repeat
        self.buffer_size = buffer_size
        self.max_length = max_length
        self.num_workers = num_workers
        # Shard the file list across distributed ranks.
        self.files = self.files[self.rank::self.world_size]

    def read_parquet_file(self, file_path):
        # Loads the whole shard into memory; acceptable for modest shard sizes.
        table = pq.read_table(file_path, columns=["content"])
        df = table.to_pandas()
        for _, row in df.iterrows():
            yield {"content": row["content"]}

    def __iter__(self):
        while True:
            # Copy before shuffling so self.files is not reordered in place.
            file_list = list(self.files)
            if self.shuffle:
                random.shuffle(file_list)
            for file in file_list:
                data_generator = self.read_parquet_file(file)
                buffer = []
                for data in data_generator:
                    text = data["content"].replace("\n", "")
                    if len(text) > self.max_length:
                        # Crop a random max_length-character window.
                        start_index = random.randint(0, len(text) - self.max_length - 1)
                        selected_text = text[start_index:start_index + self.max_length]
                    else:
                        selected_text = text
                    buffer.append({"input_ids": selected_text})
                    if len(buffer) >= self.buffer_size:
                        if self.shuffle:
                            random.shuffle(buffer)
                        for item in buffer:
                            yield item
                        buffer = []
                if buffer:
                    if self.shuffle:
                        random.shuffle(buffer)
                    for item in buffer:
                        yield item
            if not self.repeat:
                break

    def collate_fn(self, batch):
        batched = collections.defaultdict(list)
        for data in batch:
            for k, v in data.items():
                batched[k].append(v)
        for k, v in batched.items():
            # Stack tensor fields; keep string/scalar fields as plain lists.
            if k not in ('key', 'input_ids', 'similarity'):
                batched[k] = torch.stack(v, dim=0)
        return batched


class ChatDataset(IterableDataset):
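    """Stream chat/SFT text from sharded Parquet files.

    Same pipeline as RefinedWebDataset, except newlines are preserved and,
    when a ``tokenizer`` is given, samples longer than ``max_length`` tokens
    are skipped outright rather than cropped to a character window.

    Example (sketch; the path and model name are placeholders)::

        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained("<your-model>")
        ds = ChatDataset("data/*.parquet", tokenizer=tok, max_length=4096)
        loader = DataLoader(ds, batch_size=8, collate_fn=ds.collate_fn)
    """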
    def __init__(self,
                 data_path,
                 rank: int = 0,
                 world_size: int = 1,
                 shuffle=True,
                 repeat=True,
                 buffer_size=1000,
                 max_length=8000,
                 num_workers=1,
                 tokenizer=None):
        super().__init__()
        self.files = sorted(glob.glob(data_path))
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.repeat = repeat
        self.buffer_size = buffer_size
        self.max_length = max_length
        self.num_workers = num_workers
        self.tokenizer = tokenizer
        self.files = self.files[self.rank::self.world_size]

    def read_parquet_file(self, file_path):
        table = pq.read_table(file_path, columns=["content"])
        df = table.to_pandas()
        for _, row in df.iterrows():
            yield {"content": row["content"]}

    def __iter__(self):
        while True:
            # Copy before shuffling so self.files is not reordered in place.
            file_list = list(self.files)
            if self.shuffle:
                random.shuffle(file_list)
            for file in file_list:
                data_generator = self.read_parquet_file(file)
                buffer = []
                for data in data_generator:
                    text = data["content"]
                    if self.tokenizer is None:
                        if len(text) > self.max_length:
                            start_index = random.randint(0, len(text) - self.max_length - 1)
                            selected_text = text[start_index:start_index + self.max_length]
                        else:
                            selected_text = text
                    else:
                        # With a tokenizer, drop over-length samples instead of cropping.
                        if len(self.tokenizer(text)['input_ids']) < self.max_length:
                            selected_text = text
                        else:
                            continue
                    buffer.append({"input_ids": selected_text})
                    if len(buffer) >= self.buffer_size:
                        if self.shuffle:
                            random.shuffle(buffer)
                        for item in buffer:
                            yield item
                        buffer = []
                if buffer:
                    if self.shuffle:
                        random.shuffle(buffer)
                    for item in buffer:
                        yield item
            if not self.repeat:
                break

    def collate_fn(self, batch):
        batched = collections.defaultdict(list)
        for data in batch:
            for k, v in data.items():
                batched[k].append(v)
        for k, v in batched.items():
            if k not in ('key', 'input_ids', 'similarity'):
                batched[k] = torch.stack(v, dim=0)
        return batched


class R2iDataset(IterableDataset):
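    """Stream (image, reasoning-prompt) pairs for text-to-image training.

    Walks subdirectories of ``data_path`` containing ``<name>.jpg`` images
    alongside optional ``<name>.caption`` (detailed) and
    ``<name>.shortcaption`` files, and builds a Llama-3-style prompt asking
    the model to "think" out the detailed caption inside <think> </think>
    tags before producing the image.
    """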
    def __init__(self,
                 data_path,
                 rank: int = 0,
                 world_size: int = 1,
                 shuffle=True,
                 repeat=True,
                 buffer_size=1000,
                 max_length=8000,
                 num_workers=1,
                 resolution=256,
                 tokenizer=None):
        super().__init__()
        self.data_path = data_path
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.repeat = repeat
        self.buffer_size = buffer_size
        self.max_length = max_length
        self.num_workers = num_workers
        self.tokenizer = tokenizer
        self.resolution = resolution

    def __iter__(self):
        while True:
            subdirs = sorted([d for d in glob.glob(os.path.join(self.data_path, "*")) if os.path.isdir(d)])
            # Shard before shuffling: shuffling first with unsynchronized
            # per-rank RNG state would give ranks overlapping subsets.
            subdirs = subdirs[self.rank::self.world_size]
            if self.shuffle:
                random.shuffle(subdirs)
            for subdir in subdirs:
                # Collect unique base names; each may have .jpg/.caption/.shortcaption files.
                all_files = glob.glob(os.path.join(subdir, "*.*"))
                base_names = set()
                for file_path in all_files:
                    base_name = os.path.splitext(os.path.basename(file_path))[0]
                    base_names.add(base_name)
                base_names = list(base_names)
                if self.shuffle:
                    random.shuffle(base_names)
                buffer = []
                for base_name in base_names:
                    jpg_path = os.path.join(subdir, f"{base_name}.jpg")
                    caption_path = os.path.join(subdir, f"{base_name}.caption")
                    shortcaption_path = os.path.join(subdir, f"{base_name}.shortcaption")
                    if not os.path.exists(jpg_path):
                        continue
                    try:
                        image = Image.open(jpg_path).convert("RGB")
                        caption = ""
                        if os.path.exists(caption_path):
                            with open(caption_path, "r", encoding="utf-8") as f:
                                caption = f.read().strip()
                        short_caption = ""
                        if os.path.exists(shortcaption_path):
                            with open(shortcaption_path, "r", encoding="utf-8") as f:
                                short_caption = f.read().strip()
                        transformed_image = image_transform_clip({"images": image}, resolution=self.resolution)["images"]
                        if self.tokenizer is not None:
                            # Reserve two positions for special tokens.
                            if len(self.tokenizer(caption)['input_ids']) > self.max_length - 2:
                                continue
                        prompt = (
                            '<|start_header_id|>user<|end_header_id|>\n'
                            "You should first think out a more detailed version of the description and then provide the user with the image. The detailed description is enclosed within <think> </think> tags, i.e. <think> detailed description here </think> image here\n"
                            f"{short_caption}"
                            '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
                            f"<think>{caption}</think>"
                        )
                        sample = {
                            "images": transformed_image,
                            "input_ids": prompt,
                        }
                        buffer.append(sample)
                        if len(buffer) >= self.buffer_size:
                            if self.shuffle:
                                random.shuffle(buffer)
                            for item in buffer:
                                yield item
                            buffer = []
                    except Exception as e:
                        print(f"Error processing {jpg_path}: {e}")
                        continue
                if buffer:
                    if self.shuffle:
                        random.shuffle(buffer)
                    for item in buffer:
                        yield item
            if not self.repeat:
                break

    def collate_fn(self, batch):
        batched = collections.defaultdict(list)
        for data in batch:
            for k, v in data.items():
                batched[k].append(v)
        for k, v in batched.items():
            if k not in ('key', 'input_ids', 'similarity'):
                batched[k] = torch.stack(v, dim=0)
        return batched


class VQADataset(IterableDataset):
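    """Stream (image, conversation) pairs from a LLaVA-style JSON file.

    Each JSON item holds an image path (relative to ``image_root``) and a
    list of human/assistant turns. A random prefix of complete turns is
    selected, formatted with the tokenizer's chat template, and yielded
    together with the transformed image. A ``tokenizer`` is required.
    """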
    def __init__(self,
                 json_path: str,
                 image_root: str,
                 tokenizer=None,
                 rank: int = 0,
                 world_size: int = 1,
                 shuffle: bool = True,
                 repeat: bool = True,
                 buffer_size: int = 100,
                 resolution: int = 256,
                 max_length: int = 8000,
                 num_workers: int = 1,
                 image_transform_method: str = "squash"):
        super().__init__()
        self.json_path = json_path
        self.image_root = image_root
        self.tokenizer = tokenizer
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.repeat = repeat
        self.buffer_size = buffer_size
        self.resolution = resolution
        self.max_length = max_length
        self.num_workers = num_workers
        self.image_transform_method = image_transform_method
        try:
            with open(self.json_path, 'r', encoding='utf-8') as f:
                raw_data = json.load(f)
        except FileNotFoundError:
            print(f"Error: Data file not found at {self.json_path}")
            self.list_data_dict = []
        except json.JSONDecodeError:
            print(f"Error: Could not decode JSON from {self.json_path}")
            self.list_data_dict = []
        else:
            # Keep only items that carry both an image and a conversation.
            self.list_data_dict = [item for item in raw_data if 'image' in item and 'conversations' in item]
        self.list_data_dict = self.list_data_dict[self.rank::self.world_size]

    def __iter__(self):
        sot_token = '<|startoftext|>'
        assistant_prompt_suffix = '<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        while True:
            current_data_list = list(self.list_data_dict)
            if self.shuffle:
                random.shuffle(current_data_list)
            buffer = []
            for item in current_data_list:
                image_relative_path = item.get('image')
                conversations = item.get('conversations', [])
                if not image_relative_path or not conversations or len(conversations) < 2:
                    continue
                num_total_messages = len(conversations)
                if num_total_messages % 2 != 0:
                    # Drop a trailing unpaired message so turns come in (human, assistant) pairs.
                    conversations = conversations[:-1]
                    num_total_messages -= 1
                if num_total_messages < 2:
                    continue
                num_turns = num_total_messages // 2
                if num_turns == 0:
                    continue
                # Train on a random prefix of complete turns.
                selected_num_turns = random.randint(1, num_turns)
                selected_conversations = conversations[:selected_num_turns * 2]
                image_path = os.path.join(self.image_root, image_relative_path)
                try:
                    image = Image.open(image_path).convert("RGB")
                    if self.image_transform_method == "squash":
                        transformed_image = image_transform_squash({"images": image}, resolution=self.resolution)["images"]
                    elif self.image_transform_method == "pad":
                        transformed_image = image_transform_pad({"images": image}, resolution=self.resolution)["images"]
                    else:
                        transformed_image = image_transform_clip({"images": image}, resolution=self.resolution)["images"]
                    # Strip the <image> placeholder from the first human message.
                    first_human_message = selected_conversations[0]['value']
                    processed_message = first_human_message.replace('<image>\n', '').replace('\n<image>', '')
                    current_selection_messages = list(selected_conversations)
                    current_selection_messages[0] = dict(current_selection_messages[0])
                    current_selection_messages[0]['value'] = processed_message
                    messages = []
                    for turn in current_selection_messages:
                        role = "user" if turn["from"] == "human" else "assistant"
                        messages.append({"role": role, "content": turn["value"]})
                    formatted_text = self.tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True
                    )
                    # Remove the template's leading start-of-text token and the
                    # trailing generation prompt added by add_generation_prompt.
                    if formatted_text.startswith(sot_token):
                        formatted_text = formatted_text[len(sot_token):]
                    if formatted_text.endswith(assistant_prompt_suffix):
                        formatted_text = formatted_text[:-len(assistant_prompt_suffix)]
                    token_ids = self.tokenizer(formatted_text)['input_ids']
                    if len(token_ids) > self.max_length:
                        continue
                    sample = {
                        "images": transformed_image,
                        "input_ids": formatted_text,
                    }
                    buffer.append(sample)
                    if len(buffer) >= self.buffer_size:
                        if self.shuffle:
                            random.shuffle(buffer)
                        for buf_item in buffer:
                            yield buf_item
                        buffer = []
                except FileNotFoundError:
                    print(f"Warning: Image file not found at {image_path}, skipping item.")
                    continue
                except Exception as e:
                    print(f"Warning: Error processing item with image {image_path}: {e}, skipping.")
                    continue
            if buffer:
                if self.shuffle:
                    random.shuffle(buffer)
                for buf_item in buffer:
                    yield buf_item
            if not self.repeat:
                break

    def collate_fn(self, batch):
        batched = collections.defaultdict(list)
        for data in batch:
            for k, v in data.items():
                batched[k].append(v)
        for k, v in batched.items():
            if k not in ('key', 'input_ids', 'similarity'):
                batched[k] = torch.stack(v, dim=0)
        return batched


def image_transform_clip(sample, resolution=256):
    image = sample["images"]
    image = transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BICUBIC)(image)
    image = transforms.CenterCrop((resolution, resolution))(image)
    image = transforms.ToTensor()(image)
    image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
    sample["images"] = image
    return sample


def image_transform_squash(sample, resolution=256):
    image = sample["images"]
    image = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC)(image)
    image = transforms.ToTensor()(image)
    image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image)
    sample["images"] = image
    return sample


def image_transform_pad(sample, resolution=256, fill_color=(255, 255, 255)):
    image = sample["images"]
    w, h = image.size
    if w == h:
        padded_image = image
    elif w < h:
        padding_needed = h - w
        padding_left = padding_needed // 2
        padding_right = padding_needed - padding_left
        pad_transform = transforms.Pad((padding_left, 0, padding_right, 0), fill=fill_color, padding_mode='constant')
        padded_image = pad_transform(image)
    else:
        padding_needed = w - h
        padding_top = padding_needed // 2
        padding_bottom = padding_needed - padding_top
        pad_transform = transforms.Pad((0, padding_top, 0, padding_bottom), fill=fill_color, padding_mode='constant')
        padded_image = pad_transform(image)
    image_resized = transforms.Resize((resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC)(padded_image)
    image_tensor = transforms.ToTensor()(image_resized)
    image_normalized = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])(image_tensor)
    sample["images"] = image_normalized
    return sample
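

# The three transforms differ only in how they reach a square resolution:
# "clip" resizes the shorter side then center-crops (may cut off content),
# "squash" resizes both sides directly (may distort the aspect ratio), and
# "pad" letterboxes with fill_color before resizing (preserves the aspect
# ratio). All three normalize pixel values to roughly [-1, 1].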


if __name__ == '__main__':
    data_path = "/data_storage/shared/datasets/falcon-refinedweb/data/data/*.parquet"
    dataset = RefinedWebDataset(
        data_path=data_path,
        max_length=8000,
        buffer_size=0,  # buffer_size=0 flushes every sample immediately (no shuffle buffering)
    )
    train_dataloader = DataLoader(
        dataset,
        batch_size=1,
        sampler=None,
        collate_fn=dataset.collate_fn,
        num_workers=0
    )
    print("Starting data loading test...")
    for i, batch in enumerate(train_dataloader):
        if i == 0:
            print(batch)
            print(f"Batch size: {len(batch['input_ids'])}")
            print(f"First sample length: {len(batch['input_ids'][0])}")
        if i >= 5:
            break
    print("Data loading test complete")