sync token_dictionary variable name w/ classifier
Browse files- geneformer/classifier.py +4 -4
geneformer/classifier.py
CHANGED
|
@@ -1065,11 +1065,11 @@ class Classifier:
|
|
| 1065 |
# define the data collator
|
| 1066 |
if self.classifier == "cell":
|
| 1067 |
data_collator = DataCollatorForCellClassification(
|
| 1068 |
-
token_dictionary=self.
|
| 1069 |
)
|
| 1070 |
elif self.classifier == "gene":
|
| 1071 |
data_collator = DataCollatorForGeneClassification(
|
| 1072 |
-
token_dictionary=self.
|
| 1073 |
)
|
| 1074 |
|
| 1075 |
# define function to initiate model
|
|
@@ -1242,11 +1242,11 @@ class Classifier:
|
|
| 1242 |
# define the data collator
|
| 1243 |
if self.classifier == "cell":
|
| 1244 |
data_collator = DataCollatorForCellClassification(
|
| 1245 |
-
token_dictionary=self.
|
| 1246 |
)
|
| 1247 |
elif self.classifier == "gene":
|
| 1248 |
data_collator = DataCollatorForGeneClassification(
|
| 1249 |
-
token_dictionary=self.
|
| 1250 |
)
|
| 1251 |
|
| 1252 |
# create the trainer
|
|
|
|
| 1065 |
# define the data collator
|
| 1066 |
if self.classifier == "cell":
|
| 1067 |
data_collator = DataCollatorForCellClassification(
|
| 1068 |
+
token_dictionary=self.gene_token_dict
|
| 1069 |
)
|
| 1070 |
elif self.classifier == "gene":
|
| 1071 |
data_collator = DataCollatorForGeneClassification(
|
| 1072 |
+
token_dictionary=self.gene_token_dict
|
| 1073 |
)
|
| 1074 |
|
| 1075 |
# define function to initiate model
|
|
|
|
| 1242 |
# define the data collator
|
| 1243 |
if self.classifier == "cell":
|
| 1244 |
data_collator = DataCollatorForCellClassification(
|
| 1245 |
+
token_dictionary=self.gene_token_dict
|
| 1246 |
)
|
| 1247 |
elif self.classifier == "gene":
|
| 1248 |
data_collator = DataCollatorForGeneClassification(
|
| 1249 |
+
token_dictionary=self.gene_token_dict
|
| 1250 |
)
|
| 1251 |
|
| 1252 |
# create the trainer
|