sky-2002 committed on
Commit 55878ce · verified · 1 Parent(s): fc96ac7

Upload deepseek_tinystories/processor.py

Files changed (1)
  1. deepseek_tinystories/processor.py +183 -0
deepseek_tinystories/processor.py ADDED
@@ -0,0 +1,183 @@
import tiktoken
import os
import numpy as np
from datasets import load_dataset
from tqdm.auto import tqdm
import torch
from typing import List


class TinyStoriesProcesssor:
    """Tokenize TinyStories with a tiktoken BPE tokenizer and serve data for training."""

    def __init__(self, tokenizer_name: str = "gpt2", max_length: int = 1024):
        self.tokenizer = tiktoken.get_encoding(tokenizer_name)
        self.max_length = max_length
        # In-memory token tensors keyed by split name; used by
        # prepare_dataset_memory() and get_dataset() below.
        self.memory_datasets = {}

        self.data_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data"
        )
        os.makedirs(self.data_dir, exist_ok=True)
        print(f"Data directory: {self.data_dir}")

    def tokenize(self, text: str) -> List[int]:
        """Encode text and truncate to max_length tokens."""
        tokens = self.tokenizer.encode(text)
        if len(tokens) > self.max_length:
            tokens = tokens[: self.max_length]
        return tokens

    def detokenize(self, tokens: List[int]) -> str:
        return self.tokenizer.decode(tokens)

    def process(self, example):
        """Map a dataset example to its token ids and token count."""
        text = example["text"]
        tokens = self.tokenize(text)
        return {"input_ids": tokens, "len": len(tokens)}

    def prepare_dataset(
        self,
        dataset_name: str = "roneneldan/TinyStories",
        split: str = "train",
        debug: bool = False,
    ):
        train_path = os.path.join(self.data_dir, "train.bin")
        validation_path = os.path.join(self.data_dir, "val.bin")
        test_path = os.path.join(self.data_dir, "test.bin")

        ds = load_dataset(dataset_name, split=split)

        if debug:
            print("Debug mode: using a small subset of the data")
            ds = ds.select(range(1024))

        if (
            os.path.exists(train_path)
            and os.path.exists(validation_path)
            and os.path.exists(test_path)
        ):
            print("Found existing processed files!")
            print(f"Train file: {os.path.getsize(train_path) / (1024*1024):.2f} MB")
            print(
                f"Validation file: {os.path.getsize(validation_path) / (1024*1024):.2f} MB"
            )
            print(f"Test file: {os.path.getsize(test_path) / (1024*1024):.2f} MB")

            return {
                "train": train_path,
                "validation": validation_path,
                "test": test_path,
            }

        # 80/10/10 split: hold out 20%, then split the held-out part
        # half-and-half into validation and test.
        train_val_test = ds.train_test_split(test_size=0.2, seed=42)
        val_finetune = train_val_test["test"].train_test_split(test_size=0.5, seed=42)

        # Create a new dataset dictionary with all splits
        ds = {
            "train": train_val_test["train"],
            "validation": val_finetune["train"],
            "test": val_finetune["test"],
        }

        for split_name, split_data in ds.items():
            print(f"\nProcessing {split_name} split...")

            # Tokenize the split in parallel
            tokenized = split_data.map(
                self.process,
                desc=f"tokenizing {split_name} split",
                num_proc=8,
            )

            tokenized = tokenized.filter(lambda x: x["len"] > 0)
            print(f"After processing: {len(tokenized)} valid examples")

            filename = os.path.join(self.data_dir, f"{split_name}.bin")
            print(f"Saving {split_name} split to: {filename}")

            # Write all token ids of the split into one flat uint16 memmap file.
            arr_len = np.sum(tokenized["len"], dtype=np.uint64)
            dtype = np.uint16
            arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
            # Cap the shard count at the number of examples so no shard is empty
            # (small debug subsets would otherwise break np.concatenate).
            total_batches = min(1024, len(tokenized))

            idx = 0
            for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
                batch = tokenized.shard(
                    num_shards=total_batches, index=batch_idx, contiguous=True
                ).with_format("numpy")
                arr_batch = np.concatenate(batch["input_ids"])
                arr[idx : idx + len(arr_batch)] = arr_batch
                idx += len(arr_batch)
            arr.flush()

            if os.path.exists(filename):
                print(f"Successfully created {filename}")
                print(f"File size: {os.path.getsize(filename) / (1024*1024):.2f} MB")
            else:
                raise RuntimeError(f"Failed to create {filename}")

        return {
            "train": train_path,
            "validation": validation_path,
            "test": test_path,
        }

    def load_binary_data(self, filepath: str) -> torch.Tensor:
        """Load binary data file as tensor"""
        try:
            data = np.memmap(filepath, dtype=np.uint16, mode="r")
            # Copy so the tensor is backed by regular writable memory rather
            # than the read-only memmap buffer.
            return torch.from_numpy(data.copy())
        except Exception as e:
            print(f"Error loading data from {filepath}: {e}")
            raise

    def get_batch(self, data: torch.Tensor, batch_size: int, block_size: int) -> tuple:
        """Get a batch of data for training"""
        ix = torch.randint(len(data) - block_size, (batch_size,))

        # Targets are the inputs shifted one token to the right.
        x = torch.stack([data[i : i + block_size].long() for i in ix])
        y = torch.stack([data[i + 1 : i + 1 + block_size].long() for i in ix])

        return x, y

    def prepare_dataset_memory(
        self,
        dataset_name: str = "roneneldan/TinyStories",
        debug: bool = False,
        splits: List[str] = ["train", "validation", "test"],
    ):
        """Load, tokenize, and keep dataset fully in memory."""
        print("Loading dataset into memory...")
        ds = load_dataset(dataset_name)

        if debug:
            print("Debug mode: using a small subset of the data")
            for split in ds:
                ds[split] = ds[split].select(range(min(10240, len(ds[split]))))

        for split in splits:
            if split not in ds:
                # Not every dataset ships every split; skip missing ones
                # instead of raising a KeyError.
                print(f"Split {split} not found in {dataset_name}, skipping")
                continue
            print(f"\nProcessing {split} split (in memory)...")
            tokenized = ds[split].map(
                self.process,
                desc=f"tokenizing {split} split",
            )
            tokenized = tokenized.filter(lambda x: x["len"] > 0)
            print(f"After processing: {len(tokenized)} valid examples")

            # Flatten into one long array of token IDs
            arr = np.concatenate(tokenized["input_ids"])
            arr = torch.tensor(arr, dtype=torch.long)
            self.memory_datasets[split] = arr

        return self.memory_datasets

    def get_dataset(self, split: str = "train") -> torch.Tensor:
        """Return in-memory dataset tensor for a split."""
        if split not in self.memory_datasets:
            raise ValueError(
                f"Split {split} not found. Call prepare_dataset_memory first."
            )
        return self.memory_datasets[split]


if __name__ == "__main__":
    processor = TinyStoriesProcesssor(tokenizer_name="gpt2", max_length=512)
    processor.prepare_dataset(split="train", debug=True)
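
For context, a minimal usage sketch of the disk-backed path follows. It is not part of the uploaded file: it assumes the module is importable as deepseek_tinystories.processor and that load_dataset can fetch roneneldan/TinyStories; the batch_size and block_size values are illustrative, not taken from the commit.

# Usage sketch (illustrative; assumes the package layout deepseek_tinystories/processor.py).
from deepseek_tinystories.processor import TinyStoriesProcesssor

processor = TinyStoriesProcesssor(tokenizer_name="gpt2", max_length=512)

# prepare_dataset returns the paths of the .bin files it wrote (or found).
paths = processor.prepare_dataset(split="train", debug=True)

# Memory-map the training tokens and sample one (input, target) batch.
train_data = processor.load_binary_data(paths["train"])
x, y = processor.get_batch(train_data, batch_size=8, block_size=256)
print(x.shape, y.shape)  # torch.Size([8, 256]) for both

# Round-trip check: decode the first sampled block back to text.
print(processor.detokenize(x[0].tolist())[:200])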