from __future__ import annotations

from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

from mario_gpt.level import FULL_LEVEL_STR_WITH_PATHS

DEFAULT_MODEL = "distilgpt2"


def split_given_size(a, size):
    # Split array `a` into consecutive chunks of length `size` (the last chunk may be shorter).
    return np.split(a, np.arange(size, len(a), size))


def flip_and_transpose(arr: np.array, flip_first: bool = False):
    # Rearrange a (height x width) character grid so the level is read column by
    # column; `flip_first` reverses the order of the flip and transpose operations.
    if arr.shape[-1] > 1:
        if flip_first:
            return np.flip(arr, -1).transpose()
        return np.flip(arr.transpose(), -1)
    return arr


def join_list_of_list(str_lists):
    # [["a", "b"], ["c", "d"]] -> ["ab", "cd"]
    return ["".join(s) for s in str_lists]


def characterize(str_lists):
    # ["ab", "cd"] -> [["a", "b"], ["c", "d"]]
    return [list(s) for s in str_lists]
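
# Example of how the helpers above serialize a level (hypothetical 2x3 grid):
#     rows = ["ABC",
#             "DEF"]
#     characterize(rows)                               -> [["A","B","C"], ["D","E","F"]]
#     flip_and_transpose(np.array(characterize(rows))) -> [["D","A"], ["E","B"], ["F","C"]]
#     join_list_of_list(...)                           -> ["DA", "EB", "FC"]
# i.e. the level string is consumed column by column, bottom tile first.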


class MarioDataset(Dataset):
    def __init__(
        self,
        tokenizer: Optional[PreTrainedTokenizer] = None,
        level_string: Optional[str] = None,
        context_len: int = 700,
        height: int = 14,
        remove_start_end_tokens: bool = False,
        sample_all_indices: bool = False,
    ):
        if level_string is None:
            print(
                "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS..."
            )
            level_string = FULL_LEVEL_STR_WITH_PATHS
        elif ".txt" in level_string:
            # Treat the argument as a path and load the level from disk.
            with open(level_string, "r") as file:
                level_string = file.read()

        self.character_set = set(level_string)
        if "\n" in self.character_set:
            self.character_set.remove("\n")
        self.vocab_size = len(self.character_set)
        self.sample_all_indices = sample_all_indices

        def get_training_corpus():
            yield list(level_string)

        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
        self.tokenizer = tokenizer
        if getattr(tokenizer, "train_new_from_iterator", None) is not None:
            # Fast tokenizer: retrain it on the level characters.
            self.tokenizer = tokenizer.train_new_from_iterator(
                get_training_corpus(), 52000
            )
        elif getattr(tokenizer, "train_from_iterator", None) is not None:
            # Raw `tokenizers` object: wrap it, then retrain on the level characters.
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_object=self.tokenizer)
            self.tokenizer = self.tokenizer.train_new_from_iterator(
                get_training_corpus(), self.vocab_size
            )

        self.context_len = context_len
        self.height = height

        x, self.str_arr = self.convert_level_to_tensor(level_string.split("\n"))
        self.input_ids = x["input_ids"].squeeze()
        self.attention_masks = x["attention_mask"].squeeze()
        if remove_start_end_tokens:
            self.input_ids = self.input_ids[1:-1]
            self.attention_masks = self.attention_masks[1:-1]

        self.indices = self.generate_indices()

        self.unique_tokens, self.unique_counts = self.input_ids.unique(
            return_counts=True
        )
        self.weighted_unique_counts = (
            1.0 / self.unique_counts / torch.sum(self.unique_counts)
        )

        # Map each decoded token string back to its integer token id.
        self.token_dict = {}
        string_tokens = list(self.tokenizer.decode(self.unique_tokens))
        for int_token, string_token in zip(self.unique_tokens, string_tokens):
            self.token_dict[string_token] = int_token
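
    # The constructed dataset exposes, among others:
    #   self.input_ids / self.attention_masks : 1-D tensors over the serialized level
    #   self.indices                          : one row of token positions per training window
    #   self.token_dict                       : decoded token string -> integer token id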

    def convert_level_to_tensor(self, level: List[str]):
        str_arr = flip_and_transpose(np.array(characterize(level)))
        str_arr = "".join(join_list_of_list(str_arr))
        x = self.tokenizer(str_arr, return_tensors="pt")
        return x, str_arr

    def __len__(self):
        return self.indices.shape[0]

    def __getitem__(self, idx):
        indices = self.indices[idx]
        return self.input_ids[indices], self.attention_masks[indices]

    def generate_indices(self):
        # Build one window of `context_len` consecutive token positions for every
        # column-aligned offset (or every offset when `sample_all_indices` is set).
        out = []
        for idx in range(self.input_ids.shape[0] - self.context_len):
            if idx % self.height == 0 or self.sample_all_indices:
                arange = torch.arange(idx, idx + self.context_len)
                out.append(arange)
        return torch.stack(out)
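
    # With the default height=14 and context_len=700, windows start at token
    # offsets 0, 14, 28, ... and each window spans 700 tokens, i.e. 50 full columns.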

    def sample_indices(self, batch_size):
        out = []
        for _ in range(batch_size):
            start_idx = np.random.randint(0, self.__len__() - self.context_len)
            indices = torch.arange(start_idx, start_idx + self.context_len)
            out.append(indices)
        return torch.stack(out)

    def __str__(self):
        str_list = characterize(self.tokenizer.batch_decode(self.input_ids))
        string = "\n".join(
            join_list_of_list(flip_and_transpose(np.array(str_list), True))
        )
        return string

    def generate_mask(self, mask_len: int, batch_size: int = 1):
        # Fill a (batch_size, mask_len) tensor with the token id obtained by
        # encoding "<mask>" and taking index 1 of the resulting input_ids.
        mask_token = self.tokenizer("<mask>").input_ids[1]
        ones = torch.ones((batch_size, mask_len))
        return ones * mask_token

    def apply_mask(self, level, masked_indices, mask=None):
        # Replace the tokens at `masked_indices` with mask tokens, returning a copy.
        if len(level.shape) == 1:
            level = level.unsqueeze(0)
        batch_size = level.shape[0]
        mask_len = masked_indices.shape[-1]
        if mask is None:
            mask = self.generate_mask(mask_len, batch_size)
        mask = mask.long().to(level.device)
        masked_level = level * torch.ones_like(level).to(level.device)
        masked_level[:, masked_indices] = mask
        return masked_level
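

# Minimal usage sketch (assumptions: the default "distilgpt2" tokenizer can be
# downloaded, and the batch size / mask indices below are arbitrary examples).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MarioDataset()  # falls back to FULL_LEVEL_STR_WITH_PATHS
    print(f"{len(dataset)} windows of {dataset.context_len} tokens each")

    # Each item is an (input_ids, attention_mask) pair of length context_len.
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    input_ids, attention_mask = next(iter(loader))
    print(input_ids.shape, attention_mask.shape)

    # Mask the first column (height tokens) of one window.
    level, _ = dataset[0]
    masked = dataset.apply_mask(level, torch.arange(0, dataset.height))
    print(masked.shape)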