Add option for variable input_size and to add CLS/SEP Tokens (#299)
- Add option for variable input_size and to add CLS/SEP Tokens (944c98c4ed4fa9bd539c2a7d8db5351161d85012)
Co-authored-by: Han Chen <[email protected]>

geneformer/tokenizer.py CHANGED (+22 -8)
@@ -81,14 +81,14 @@ class TranscriptomeTokenizer:
         custom_attr_name_dict=None,
         nproc=1,
         chunk_size=512,
+        input_size=2048,
+        special_token=False,
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
     ):
         """
         Initialize tokenizer.
-
         **Parameters:**
-
         custom_attr_name_dict : None, dict
             | Dictionary of custom attributes to be added to the dataset.
             | Keys are the names of the attributes in the loom file.
@@ -97,6 +97,10 @@ class TranscriptomeTokenizer:
             | Number of processes to use for dataset mapping.
         chunk_size: int = 512
             | Chunk size for anndata tokenizer.
+        input_size: int = 2048
+            | Input size for tokenization
+        special_token: bool = False
+            | Option to add CLS and SEP tokens
         gene_median_file : Path
             | Path to pickle file containing dictionary of non-zero median
             | gene expression values across Genecorpus-30M.
@@ -112,6 +116,12 @@ class TranscriptomeTokenizer:
         # chunk size for anndata tokenizer
         self.chunk_size = chunk_size
 
+        # input size for tokenization
+        self.input_size = input_size
+
+        # add CLS and SEP tokens
+        self.special_token = special_token
+
         # load dictionary of gene normalization factors
         # (non-zero median value of expression across Genecorpus-30M)
         with open(gene_median_file, "rb") as f:
@@ -137,9 +147,7 @@ class TranscriptomeTokenizer:
     ):
         """
         Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
-
         **Parameters:**
-
         data_directory : Path
             | Path to directory containing loom files or anndata files
         output_directory : Path
@@ -324,7 +332,7 @@ class TranscriptomeTokenizer:
                         file_cell_metadata[k] += subview.ca[k].tolist()
                 else:
                     file_cell_metadata = None
-
+
         return tokenized_cells, file_cell_metadata
 
     def create_dataset(
@@ -357,8 +365,14 @@ class TranscriptomeTokenizer:
                 example["input_ids_uncropped"] = example["input_ids"]
                 example["length_uncropped"] = len(example["input_ids"])
 
-            # Truncate/Crop input_ids to size
-
+            # Truncate/Crop input_ids to input size
+            if tk.special_token:
+                example["input_ids"] = example["input_ids"][0:self.input_size-2] # truncate to leave space for CLS and SEP token
+                example["input_ids"] = np.insert(example["input_ids"], 0, self.gene_token_dict.get("<cls>"))
+                example["input_ids"] = np.insert(example["input_ids"], len(example["input_ids"]), self.gene_token_dict.get("<sep>"))
+            else:
+                # Truncate/Crop input_ids to input size
+                example["input_ids"] = example["input_ids"][0:self.input_size]
             example["length"] = len(example["input_ids"])
 
             return example
@@ -366,4 +380,4 @@ class TranscriptomeTokenizer:
         output_dataset_truncated = output_dataset.map(
             format_cell_features, num_proc=self.nproc
         )
-        return output_dataset_truncated
+        return output_dataset_truncated
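
With this change, callers can control the truncation length and request CLS/SEP wrapping when constructing the tokenizer. A minimal usage sketch follows; it is not part of the commit, the directory paths, output prefix, and custom_attr_name_dict values are placeholders, and it assumes the existing tokenize_data(data_directory, output_directory, output_prefix, file_format=...) entry point.

# Hedged usage sketch exercising the new input_size / special_token options.
# Paths and attribute names are hypothetical placeholders.
from geneformer import TranscriptomeTokenizer

tk = TranscriptomeTokenizer(
    custom_attr_name_dict={"cell_type": "cell_type"},  # placeholder metadata mapping
    nproc=4,
    input_size=4096,      # new option: crop each cell's rank value encoding to 4096 tokens (default 2048)
    special_token=True,   # new option: wrap each cell's input_ids in <cls> ... <sep>
)
tk.tokenize_data(
    "data/loom_dir",      # placeholder input directory
    "data/tokenized",     # placeholder output directory
    "my_dataset",         # placeholder output prefix
    file_format="loom",
)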
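
The new format_cell_features branch reserves two positions when special_token is set, then inserts <cls> at the front and <sep> at the end, so a wrapped cell still comes out at input_size tokens. Below is a standalone sketch of that behavior, illustrative only: the token IDs and the tiny input_size are made up, and plain module-level variables stand in for the tokenizer's self.input_size, self.special_token, and self.gene_token_dict. Note that the added code reads tk.special_token while the neighboring lines use self.input_size and self.gene_token_dict; presumably both are meant to refer to the same tokenizer instance.

# Standalone sketch of the new truncate-and-wrap logic (not the library code).
import numpy as np

gene_token_dict = {"<cls>": 0, "<sep>": 1}  # invented IDs; real ones come from the token dictionary pickle
input_size = 8                              # stands in for self.input_size (default 2048)
special_token = True                        # stands in for self.special_token

def format_cell_features(example):
    if special_token:
        # keep input_size - 2 gene tokens, leaving room for <cls> and <sep>
        ids = np.asarray(example["input_ids"])[0 : input_size - 2]
        ids = np.insert(ids, 0, gene_token_dict.get("<cls>"))         # prepend <cls>
        ids = np.insert(ids, len(ids), gene_token_dict.get("<sep>"))  # append <sep>
    else:
        ids = np.asarray(example["input_ids"])[0:input_size]
    example["input_ids"] = ids
    example["length"] = len(example["input_ids"])
    return example

example = format_cell_features({"input_ids": list(range(100, 120))})
print(example["input_ids"])  # [  0 100 101 102 103 104 105   1]
print(example["length"])     # 8

One consequence worth noting: because the crop happens before the inserts, cells with at least input_size - 2 genes produce exactly input_size tokens, while shorter cells gain two tokens relative to the unwrapped path.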